gaoqiong / composable_kernel_ROCM / Commits

Commit e19b0897 (Unverified)
Merge branch 'develop' into gemm_bf16_sk_muozturk
Authored by Muhammed Emin Ozturk on Jan 02, 2025; committed by GitHub on Jan 02, 2025
Parents: d04cb754, 1d8e4ec2

Changes: 37 files in the merge; showing 20 changed files with 1125 additions and 334 deletions (+1125 -334).
Changed files shown below:

CMakeLists.txt (+1 -1)
cmake/EnableCompilerWarnings.cmake (+1 -1)
example/01_gemm/CMakeLists.txt (+2 -0)
example/01_gemm/common.hpp (+82 -0)
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp (+253 -0)
example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp (+8 -8)
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp (+303 -0)
example/01_gemm/gemm_xdl_fp16_v3.cpp (+10 -10)
example/01_gemm/run_gemm_example.inc (+0 -82)
example/01_gemm/run_gemm_example_streamk_v2.inc (+0 -82)
example/01_gemm/run_gemm_example_v2.inc (+0 -82)
include/ck/library/utility/host_tensor.hpp (+55 -10)
include/ck/library/utility/host_tensor_generator.hpp (+30 -0)
include/ck/tensor/static_tensor.hpp (+2 -2)
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp (+4 -0)
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp (+11 -2)
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp (+189 -0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp (+77 -27)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp (+45 -5)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp (+52 -22)
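Most of the additions in this commit deal with a packed-int4 weight type. The pk_i4 naming and the helpers added below imply two signed 4-bit values per byte, stored with a +8 offset, with the even-indexed element apparently in the high nibble. The sketch below is a standalone illustration of that assumed convention, not the ck::pk_i4_t implementation itself.

    // Minimal illustration of the packed-int4 convention implied by this diff.
    // Two signed 4-bit values share one byte; each is stored as value + 8.
    #include <cassert>
    #include <cstdint>

    uint8_t pack_i4x2(int v0, int v1) // v0, v1 in [-8, 7]
    {
        return static_cast<uint8_t>(((v0 + 8) << 4) | (v1 + 8));
    }

    void unpack_i4x2(uint8_t b, int& v0, int& v1)
    {
        v0 = ((b >> 4) & 0xf) - 8; // high nibble
        v1 = (b & 0xf) - 8;        // low nibble
    }

    int main()
    {
        int a = -3, b = 5, x, y;
        unpack_i4x2(pack_i4x2(a, b), x, y);
        assert(x == a && y == b);
    }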
CMakeLists.txt (+1 -1)

@@ -585,7 +585,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
     )
     add_subdirectory(example)
     if(BUILD_TESTING)
         add_subdirectory(test)
     endif()
 endif()
cmake/EnableCompilerWarnings.cmake (+1 -1)

@@ -66,7 +66,7 @@ else()
         -Wunreachable-code
         -Wunused
         -Wno-reserved-identifier
         -Werror
         -Wno-option-ignored
         -Wsign-compare
         -Wno-extra-semi-stmt
example/01_gemm/CMakeLists.txt (+2 -0)

@@ -29,6 +29,8 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3)
 add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
 add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp)
+add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp)
+add_example_executable(example_gemm_xdl_bf16_pk_i4_v3 gemm_xdl_bf16_pk_i4_v3.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3)
 add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3)
example/01_gemm/common.hpp (+82 -0)

@@ -287,3 +287,85 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
     return true;
 }

Added shared verification tolerances:

template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
    if constexpr(std::is_same_v<DataType, float>)            { return 1e-3; }
    else if constexpr(std::is_same_v<DataType, double>)      { return 1e-6; }
    else if constexpr(std::is_same_v<DataType, ck::half_t>)  { return 1e-3; }
    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>) { return 5e-2; }
    else if constexpr(std::is_same_v<DataType, int32_t>)     { return 1e-1; }
    else if constexpr(std::is_same_v<DataType, int8_t>)      { return 1e-1; }
    else if constexpr(std::is_same_v<DataType, ck::f8_t>)    { return 1e-1; }   // 240 and 224 are acceptable
    else if constexpr(std::is_same_v<DataType, ck::bf8_t>)   { return 1.5e-1; } // 57344 and 49152 are acceptable
    else                                                      { return 1e-3; }
}

template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
    if constexpr(std::is_same_v<DataType, float>)            { return 1e-3; }
    else if constexpr(std::is_same_v<DataType, double>)      { return 1e-6; }
    else if constexpr(std::is_same_v<DataType, ck::half_t>)  { return 1e-3; }
    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>) { return 5e-2; }
    else if constexpr(std::is_same_v<DataType, int32_t>)     { return 1e-1; }
    else if constexpr(std::is_same_v<DataType, int8_t>)      { return 1e-1; }
    else if constexpr(std::is_same_v<DataType, ck::f8_t>)    { return 16.1; }   // 240 and 224 are acceptable
    else if constexpr(std::is_same_v<DataType, ck::bf8_t>)   { return 8192.1; } // 57344 and 49152 are acceptable
    else                                                      { return 1e-3; }
}
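The 8-bit float absolute tolerances look tied to the spacing of representable values near the top of each range, as the inline comments suggest: for f8_t, 240 and 224 are adjacent representable magnitudes and 240 - 224 = 16, so atol = 16.1 accepts them as equal; for bf8_t, 57344 - 49152 = 8192, hence atol = 8192.1. This reading of the constants is an inference from the comments, not something stated elsewhere in the diff.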
example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp (new file, mode 100644, +253 -0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"

using ADataType        = ck::bhalf_t;
using BDataType        = ck::pk_i4_t;
using AccDataType      = float;
using CShuffleDataType = ck::bhalf_t;
using CDataType        = ck::bhalf_t;

using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

static constexpr bool PermuteA = false;
static constexpr bool PermuteB = true;

static constexpr ck::index_t KPerBlock = 128;

// clang-format off
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
    ALayout, BLayout, CLayout,
    ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
    AElementOp, BElementOp, CElementOp, GemmDefault,
    128,
    16, 64,
    KPerBlock, 8, 32,
    16, 16,
    1, 2,
    S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
    2, 8, 8, 0,
    S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
    2, 32, 32, 0,
    1, 1, S<1, 16, 1, 8>, 4,
    ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2,
    ADataType, ADataType, PermuteA, PermuteB>;
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<
    ADataType, BDataType, CDataType, AccDataType, PassThrough, PassThrough, PassThrough>;

template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
    using namespace ck::literals;

    auto M       = problem_size.M;
    auto N       = problem_size.N;
    auto K       = problem_size.K;
    auto StrideA = problem_size.StrideA;
    auto StrideB = problem_size.StrideB;
    auto StrideC = problem_size.StrideC;
    auto KBatch  = problem_size.KBatch;

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    auto f_get_default_stride =
        [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
            if(stride == -1)
            {
                // give a chance if stride is -1, return a default packed stride
                if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
                {
                    return static_cast<std::size_t>(col);
                }
                else
                {
                    return static_cast<std::size_t>(row);
                }
            }
            else
                return static_cast<std::size_t>(stride);
        };

    StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
    StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
    StrideC = f_get_default_stride(M, N, StrideC, CLayout{});

    Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
    Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
    Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));

    switch(config.init_method)
    {
    case 0:
        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
        break;
    case 1:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
        break;
    case 2:
        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
        break;
    case 3:
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
        break;
    default:
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0, 1.0});
        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
    }

    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));

    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
    std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;

    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());

    // weight permute
    if constexpr(PermuteB)
    {
        int K1 = KPerBlock;
        int K0 = K / KPerBlock;

        // int K0, N, K1
        for(int j = 0; j < K0; j++)
        {
            for(int i = 0; i < N; i++)
            {
                for(int jj = 0; jj < K1; jj++)
                {
                    b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
                }
            }
        }
    }
    else
    {
        for(int i = 0; i < N; i++)
        {
            for(int j = 0; j < K; j++)
            {
                b_k_n_permute(i * K + j) = b_k_n(i * K + j);
            }
        }
    }

    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
    b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
    DeviceMem workspace;

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CElementOp{};

    // do GEMM
    auto gemm    = DeviceGemmV2Instance{};
    auto invoker = gemm.MakeInvoker();
    float ave_time = 0;

    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                                      static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
                                      static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
                                      M, N, K,
                                      StrideA, StrideB, StrideC,
                                      KBatch,
                                      a_element_op, b_element_op, c_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
        return true;
    }

    bool pass = true;
    if(config.do_verification)
    {
        auto ref_gemm     = ReferenceGemmInstance{};
        auto ref_invoker  = ref_gemm.MakeInvoker();
        auto ref_argument = ref_gemm.MakeArgument(
            a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});

        ref_invoker.Run(ref_argument);

        ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0});
        c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());

        pass &= ck::utils::check_err(c_m_n_device_result,
                                     c_m_n_host_result,
                                     "Error: Incorrect results!",
                                     get_rtol<CDataType>(),
                                     get_atol<CDataType>());
    }

    if(config.time_kernel)
    {
        ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});

        std::size_t flop      = 2_uz * M * N * K;
        std::size_t num_btype =
            sizeof(ADataType) * M * K +
            sizeof(BDataType) * K * N /
                (ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
            sizeof(CDataType) * M * N;

        float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
        float gb_per_sec = num_btype / 1.E6 / ave_time;

        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                  << " GB/s, " << gemm.GetTypeString() << std::endl;
    }

    return pass;
}

bool run_gemm_splitk_example(int argc, char* argv[])
{
    ProblemSizeSplitK problem_size;
    ExecutionConfig config;

    return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config);
}

int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
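For reference, a minimal standalone sketch of the index mapping the PermuteB branch above performs (sizes are made up; plain ints stand in for the packed weights): column-major B[K, N] is regrouped into [K0][N][K1] blocks with K1 = KPerBlock.

    #include <cstdio>

    int main()
    {
        constexpr int K = 4, K1 = 2, N = 2, K0 = K / K1;
        int src[N * K], dst[K0 * N * K1];

        for(int i = 0; i < N; ++i)
            for(int k = 0; k < K; ++k)
                src[i * K + k] = 10 * i + k; // value encodes (column n, row k)

        // same mapping as the example: dst[j*N*K1 + i*K1 + jj] = src[i*K + j*K1 + jj]
        for(int j = 0; j < K0; ++j)
            for(int i = 0; i < N; ++i)
                for(int jj = 0; jj < K1; ++jj)
                    dst[j * N * K1 + i * K1 + jj] = src[i * K + (j * K1 + jj)];

        // dst now holds k-block 0 of every column, then k-block 1 of every column:
        // prints "0 1 10 11 2 3 12 13"
        for(int v : dst)
            std::printf("%d ", v);
        std::printf("\n");
    }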
example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp (+8 -8)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

 #include "common.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"

-using ADataType        = ck::f8_t;
-using BDataType        = ck::half_t;
+using ADataType        = ck::half_t;
+using BDataType        = ck::f8_t;
 using AccDataType      = float;
 using CShuffleDataType = ck::half_t;
 using CDataType        = ck::half_t;

@@ -29,15 +29,15 @@ using DeviceGemmV2Instance =
     AElementOp, BElementOp, CElementOp, GemmDefault,
 (the blocking and block-transfer tuning arguments of the instance were also retuned; the capture
  interleaves the old and new values on these lines without markers: 64, 64, 16, 16, 16, 16, 64,
  16, 8, 256, 8, 16, 16, 16, 16, 16, 1, 1, 1, 1, S<4,16,1>, S<1,0,2>, S<1,0,2>, S<32,2,1>,
  S<1,0,2>, S<1,0,2>, 2, 16, 16, 0, S<8,8,1>, S<1,0,2>, S<1,0,2>, 2, 8, 8, 0, 2, 8, 8, 0,
  S<16,4,1>, S<1,0,2>, S<1,0,2>, 2, 16, 16, 0, 1, 1, S<1,16,1,4>, 4, 1, 1, S<1,16,1,4>, 4)
-    ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
+    ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
 // clang-format on

 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp (new file, mode 100644, +303 -0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"

using ADataType        = ck::half_t;
using BDataType        = ck::pk_i4_t;
using AccDataType      = float;
using CShuffleDataType = ck::half_t;
using CDataType        = ck::half_t;

using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

static constexpr bool PermuteA = false;
static constexpr bool PermuteB = true;

static constexpr ck::index_t KPerBlock = 128;

// clang-format off
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
    ALayout, BLayout, CLayout,
    ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
    AElementOp, BElementOp, CElementOp, GemmDefault,
    128,
    16, 128,
    KPerBlock, 8, 32,
    16, 16,
    1, 4,
    S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
    2, 8, 8, 0,
    S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
    2, 32, 32, 0,
    1, 1, S<1, 16, 1, 8>, 4,
    ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2,
    ADataType, ADataType, PermuteA, PermuteB>;
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<
    ADataType, BDataType, CDataType, AccDataType, PassThrough, PassThrough, PassThrough>;

The run_gemm<ProblemType>() driver in this file is line-for-line the same as in
gemm_xdl_bf16_pk_i4_v3.cpp above (tensor setup, KPerBlock-blocked weight permute, device launch,
verification against ReferenceGemmInstance with get_rtol/get_atol, and the timed run), with two
differences: the default initializer is GeneratorTensor_3<ADataType>{0.0, 1.0}, and after the
weight permute it additionally repacks every group of eight int4 values along K:

    // vector pk_i4x4 permute
    for(int i = 0; i < N; i++)
    {
        for(int j = 0; j < K; j += 8)
        {
            int input[8];

            for(int k = 0; k < 4; k++)
            {
                int i4x2 = b_k_n_permute(j + k * 2, i).data;

                input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
                input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
            }

            // permute 01234567->20643175
            {
                int hi   = input[2];
                int lo   = input[0];
                int i4x2 = (hi << 4) | lo;

                b_k_n_permute(j + 0, i) = i4x2;
            }
            {
                int hi   = input[6];
                int lo   = input[4];
                int i4x2 = (hi << 4) | lo;

                b_k_n_permute(j + 2, i) = i4x2;
            }
            {
                int hi   = input[3];
                int lo   = input[1];
                int i4x2 = (hi << 4) | lo;

                b_k_n_permute(j + 4, i) = i4x2;
            }
            {
                int hi   = input[7];
                int lo   = input[5];
                int i4x2 = (hi << 4) | lo;

                b_k_n_permute(j + 6, i) = i4x2;
            }
        }
    }

bool run_gemm_splitk_example(int argc, char* argv[])
{
    ProblemSizeSplitK problem_size;
    ExecutionConfig config;

    return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config);
}

int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
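A standalone sketch of the "01234567 -> 20643175" nibble regrouping above, using plain bytes instead of ck::pk_i4_t; it assumes the high nibble holds the even-indexed element, as the read-out loop suggests. Eight int4 values v0..v7 held in 4 bytes are regrouped so the bytes hold (v2,v0), (v6,v4), (v3,v1), (v7,v5).

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        // bytes[k] packs (v_{2k}, v_{2k+1}) in (high, low) nibbles; values 0..7 for readability
        uint8_t bytes[4] = {0x01, 0x23, 0x45, 0x67};

        int v[8];
        for(int k = 0; k < 4; ++k)
        {
            v[2 * k + 0] = (bytes[k] >> 4) & 0xf;
            v[2 * k + 1] = bytes[k] & 0xf;
        }

        uint8_t out[4] = {
            static_cast<uint8_t>((v[2] << 4) | v[0]),
            static_cast<uint8_t>((v[6] << 4) | v[4]),
            static_cast<uint8_t>((v[3] << 4) | v[1]),
            static_cast<uint8_t>((v[7] << 4) | v[5]),
        };

        for(uint8_t b : out)
            std::printf("%02x ", b); // prints: 20 64 31 75
        std::printf("\n");
    }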
example/01_gemm/gemm_xdl_fp16_v3.cpp (+10 -10)

@@ -12,7 +12,7 @@ using CShuffleDataType = ck::half_t;
 using CDataType = ck::half_t;

 using ALayout = Row;
-using BLayout = Row;
+using BLayout = Col;
 using CLayout = Row;

 using AElementOp = PassThrough;

@@ -27,17 +27,17 @@ using DeviceGemmV2Instance =
     ALayout, BLayout, CLayout,
     ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
     PassThrough, PassThrough, PassThrough, GemmDefault,
 (the blocking and block-transfer tuning arguments of the instance were retuned for the new
  column-major B layout; the capture interleaves the old and new values on these lines without
  markers: 256, 64, 224, 256, 16, 16, 64, 8, 2, 256, 8, 8, 16, 16, 16, 16, 7, 8, 1, 1, S<8,32,1>,
  S<1,0,2>, S<1,0,2>, S<32,2,1>, S<1,0,2>, S<1,0,2>, 2, 8, 8, 0, 2, 8, 8, 0, S<8,32,1>, S<0,2,1>,
  S<0,2,1>, S<32,2,1>, S<1,0,2>, S<1,0,2>, 1, 8, 2, 0, 2, 8, 8, 0, 1, 2, S<1,32,1,8>, 8, 1, 1,
  S<1,16,1,4>, 4)
-    ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
+    ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>;
 // clang-format on

 using ReferenceGemmInstance = ck::tensor_operation::host::
example/01_gemm/run_gemm_example.inc (+0 -82)

@@ -5,88 +5,6 @@
 #include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"

Removed this file's local get_rtol<DataType>() / get_atol<DataType>() helpers, superseded by the
shared definitions added to example/01_gemm/common.hpp above. Both removed functions returned:
float 1e-3, double 1e-6, ck::half_t 1e-3, ck::bhalf_t 5e-2, int32_t 1e-1, int8_t 1e-1,
ck::f8_t 2e-1, ck::bf8_t 2e-1, and 1e-3 otherwise.

 template <typename ProblemType>
 bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
 {
example/01_gemm/run_gemm_example_streamk_v2.inc (+0 -82)

@@ -3,88 +3,6 @@
 #pragma once

Removed this file's local get_rtol<DataType>() / get_atol<DataType>() helpers. Their values match
the definitions added to common.hpp above (rtol: f8_t 1e-1, bf8_t 1.5e-1; atol: f8_t 16.1,
bf8_t 8192.1; all other types as listed there, default 1e-3).

 template <typename ProblemType>
 bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
 {
example/01_gemm/run_gemm_example_v2.inc (+0 -82)

@@ -3,88 +3,6 @@
 #pragma once

Removed this file's local get_rtol<DataType>() / get_atol<DataType>() helpers, identical to the
block removed from run_gemm_example_streamk_v2.inc and matching the definitions added to
common.hpp above.

 template <typename ProblemType>
 bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
 {
include/ck/library/utility/host_tensor.hpp (+55 -10)

@@ -266,18 +266,18 @@ struct Tensor
 using Data = std::vector<T>;

The four constructors now size mData through the new pk_i4-aware GetElementSpaceSize() wrapper
instead of calling mDesc.GetElementSpaceSize() directly, e.g.

-    template <typename X>
-    Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
+    template <typename X>
+    Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(GetElementSpaceSize())
     {
     }

and likewise for Tensor(lens, strides), Tensor(const Lengths& lens) and Tensor(const Descriptor& desc).

@@ -322,7 +322,17 @@ struct Tensor
     std::size_t GetElementSize() const { return mDesc.GetElementSize(); }

-    std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
+    std::size_t GetElementSpaceSize() const
+    {
+        if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+        {
+            return (mDesc.GetElementSpaceSize() + 1) / 2;
+        }
+        else
+        {
+            return mDesc.GetElementSpaceSize();
+        }
+    }

     std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }

@@ -469,29 +479,64 @@ struct Tensor
GetOffsetFromMultiIndex(Is... is) and the element accessors (operator()(Is... is) and
operator()(std::vector<std::size_t> idx), const and non-const) all gain the same pk_i4_t branch
that divides the descriptor offset by 2 before indexing mData, e.g.

     template <typename... Is>
     T& operator()(Is... is)
     {
-        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
+        }
+        else
+        {
+            return mData[mDesc.GetOffsetFromMultiIndex(is...)];
+        }
     }

     typename Data::iterator begin() { return mData.begin(); }
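A simplified standalone model of the addressing rule these Tensor changes introduce for pk_i4_t (an illustration, not the CK class): n logical int4 elements are backed by (n + 1) / 2 bytes, and logical offset i maps to byte i / 2, so two neighbouring logical indices alias the same byte, which the examples' permute loops rely on when they read and write whole bytes at even offsets.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct PackedI4Storage
    {
        explicit PackedI4Storage(std::size_t n_elements) : bytes((n_elements + 1) / 2) {}

        std::uint8_t& byte_for(std::size_t logical_offset) { return bytes[logical_offset / 2]; }

        std::vector<std::uint8_t> bytes;
    };

    int main()
    {
        PackedI4Storage s(5);                     // 5 logical int4 values -> 3 bytes
        assert(s.bytes.size() == 3);
        assert(&s.byte_for(2) == &s.byte_for(3)); // indices 2 and 3 share a byte
    }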
include/ck/library/utility/host_tensor_generator.hpp (+30 -0)

@@ -81,6 +81,20 @@ struct GeneratorTensor_1<int8_t>
 };

+template <>
+struct GeneratorTensor_1<ck::pk_i4_t>
+{
+    int8_t value = 1;
+
+    template <typename... Is>
+    ck::pk_i4_t operator()(Is...)
+    {
+        int t = value + 8;
+
+        ck::pk_i4_t r = ((t << 4) + t) & 0xff;
+
+        return r;
+    }
+};

 template <typename T>
 struct GeneratorTensor_2
 {

@@ -121,6 +135,22 @@ struct GeneratorTensor_2<int8_t>
 };

+template <>
+struct GeneratorTensor_2<ck::pk_i4_t>
+{
+    int min_value = 0;
+    int max_value = 1;
+
+    template <typename... Is>
+    ck::pk_i4_t operator()(Is...)
+    {
+        int hi = std::rand() % (max_value - min_value) + min_value + 8;
+        int lo = std::rand() % (max_value - min_value) + min_value + 8;
+
+        ck::pk_i4_t r = ((hi << 4) + lo) & 0xff;
+
+        return r;
+    }
+};

 #if defined CK_ENABLE_FP8
 template <>
 struct GeneratorTensor_2<ck::f8_t>
include/ck/tensor/static_tensor.hpp (+2 -2)

@@ -167,7 +167,7 @@ struct StaticTensorTupleOfVectorBuffer
 // Idx is for S, not X. Idx should be aligned with X
 template <typename X,
           typename Idx,
-          typename enable_if<has_same_scalar_type<S, X>::value &&
+          typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
                                  is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                              bool>::type = false>
 __host__ __device__ constexpr X GetAsType(Idx) const

@@ -201,7 +201,7 @@ struct StaticTensorTupleOfVectorBuffer
 // Idx is for S, not X. Idx should be aligned with X
The enable_if constraint on SetAsType(Idx, X x) is relaxed in exactly the same way:
-          typename enable_if<has_same_scalar_type<S, X>::value &&
+          typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
                                  is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                              bool>::type = false>
 __host__ __device__ constexpr void SetAsType(Idx, X x)
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp (+4 -0)

@@ -36,6 +36,10 @@ struct DeviceGemmV2 : public BaseOperator
     CElementwiseOperation c_element_op) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+
+    virtual bool GetPermuteA()         = 0;
+    virtual bool GetPermuteB()         = 0;
+    virtual ck::index_t GetKPerBlock() = 0;
 };

 template <typename ALayout,
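A hypothetical host-side helper showing how these new virtual hooks could drive the weight relayout that the pk_i4 examples currently hard-code. Only GetPermuteB() and GetKPerBlock() come from the diff; prepare_b and the flat-vector representation are illustrative, not CK APIs.

    #include <cstddef>
    #include <vector>

    // If the chosen device instance reports GetPermuteB(), relayout a column-major [K, N]
    // weight into the [K / KPerBlock, N, KPerBlock] order used by the examples; otherwise
    // return it unchanged.
    template <typename DeviceOp, typename T>
    std::vector<T> prepare_b(DeviceOp& op, const std::vector<T>& b, std::size_t K, std::size_t N)
    {
        if(!op.GetPermuteB())
            return b; // plain column-major [K, N]

        const std::size_t K1 = static_cast<std::size_t>(op.GetKPerBlock());
        const std::size_t K0 = K / K1;

        std::vector<T> out(b.size());
        for(std::size_t j = 0; j < K0; ++j)
            for(std::size_t i = 0; i < N; ++i)
                for(std::size_t jj = 0; jj < K1; ++jj)
                    out[j * N * K1 + i * K1 + jj] = b[i * K + j * K1 + jj];
        return out;
    }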
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp (+11 -2)

@@ -64,7 +64,9 @@ template <typename ALayout,
           BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
           BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
           typename ComputeTypeA = CDataType,
-          typename ComputeTypeB = ComputeTypeA>
+          typename ComputeTypeB = ComputeTypeA,
+          bool PermuteA         = false,
+          bool PermuteB         = false>
 struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                                                        BLayout,
                                                        CLayout,

@@ -122,7 +124,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
         BlkGemmPipeSched,
         BlkGemmPipelineVer,
         ComputeTypeA,
-        ComputeTypeB>;
+        ComputeTypeB,
+        PermuteA,
+        PermuteB>;

     using Argument = typename GridwiseGemm::Argument;

@@ -633,6 +637,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
     }

+    index_t GetKPerBlock() override { return KPerBlock; }
+
+    bool GetPermuteA() override { return PermuteA; }
+    bool GetPermuteB() override { return PermuteB; }
+
     static auto MakeArgument(const ADataType* p_a,
                              const BDataType* p_b,
                              CDataType* p_c,
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp (+189 -0)

@@ -7,12 +7,177 @@
 #include "ck/utility/math.hpp"
 #include "ck/utility/math_v2.hpp"
 #include "ck/utility/type_convert.hpp"
+#include "ck/utility/amd_inline_asm.hpp"

 #include <cassert>

 namespace ck {

+// Fast int4x4 to half8_t data type conversion based on paper
+// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
+// (https://arxiv.org/abs/2211.10017) and implementation:
+// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+__host__ __device__ inline half4_t pki4_to_half4(int q)
+{
+    const int LO = 0x000f000f;
+    const int HI = 0x00f000f0;
+    const int EX = 0x64006400;
+
+    // Extract the two int4 at low bit and create two fp16 number.
+    int lo = amd_assembly_and_or_b32(q, LO, EX);
+    // Extract the two int4 at hight bit and create two fp16 number.
+    int hi = amd_assembly_and_or_b32(q, HI, EX);
+
+    const int SUB = 0xE408E408; // half2 {-1032, -1032}
+    const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
+    const int ADD = 0xd480d480; // half2 {-72, -72}
+
+    vector_type<half_t, 4> res;
+
+    // for two fp16 from lowbit, subtract 1032 to get correct fp16 value
+    res.template AsType<half2_t>()(Number<0>{}) =
+        amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
+
+    // for two fp16 from highbit, divide 16 and subtract 72 to get correct fp16 value
+    res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
+        bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));
+
+    return res.template AsType<half4_t>()[Number<0>{}];
+}
+
+__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
+{
+#if 1
+    uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
+    uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
+
+    const int EX  = 0x64006400;
+    const int SUB = 0xE408E408; //-8
+
+    int lo = i4s | EX;
+
+    return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
+#else
+    uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
+
+    vector_type<half_t, 2> res;
+
+    half_t x_h = (x_u8 & 0x0f) - 8;
+    half_t x_l = ((x_u8 & 0xf0) >> 4) - 8;
+
+    res.template AsType<half_t>()(Number<0>{}) = x_l;
+    res.template AsType<half_t>()(Number<1>{}) = x_h;
+
+    return res.template AsType<half2_t>()[Number<0>{}];
+#endif
+}
+
+__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
+{
+    uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
+
+    static constexpr uint32_t fp32_base = 0x4B000000;
+
+    float fp32_intermediates[4];
+
+    uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
+
+    fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
+    fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651);
+    fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652);
+    fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
+
+    fp32_intermediates[0] -= 8388616.f;
+    fp32_intermediates[1] -= 8388616.f;
+    fp32_intermediates[2] -= 8388616.f;
+    fp32_intermediates[3] -= 8388616.f;
+
+    vector_type<bhalf_t, 4> res;
+    res.template AsType<bhalf2_t>()(Number<0>{}) = bit_cast<bhalf2_t>(
+        __byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632));
+    res.template AsType<bhalf2_t>()(Number<1>{}) = bit_cast<bhalf2_t>(
+        __byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632));
+
+    return res.template AsType<bhalf4_t>()[Number<0>{}];
+}
+
+__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
+{
+    uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
+
+    float x_h = ((x_u8 & 0x0f) >> 0) - 8.f;
+    float x_l = ((x_u8 & 0xf0) >> 4) - 8.f;
+
+    vector_type<bhalf_t, 2> res;
+
+    res.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(x_l);
+    res.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(x_h);
+
+    return res.template AsType<bhalf2_t>()[Number<0>{}];
+}
+
 namespace tensor_operation {
 namespace element_wise {

+struct PassThroughPack8
+{
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const;
+
+    __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
+    {
+#if 1
+        vector_type<half_t, 8> result;
+
+        result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
+        result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
+
+        y = result.template AsType<half8_t>()[Number<0>{}];
+#else
+        vector_type<half_t, 8> dst;
+        vector_type<pk_i4_t, 4> src{x};
+
+        dst.template AsType<half2_t>()(Number<0>{}) =
+            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
+        dst.template AsType<half2_t>()(Number<1>{}) =
+            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
+        dst.template AsType<half2_t>()(Number<2>{}) =
+            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
+        dst.template AsType<half2_t>()(Number<3>{}) =
+            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
+
+        y = dst.template AsType<half8_t>()[Number<0>{}];
+#endif
+    }
+
+    __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
+    {
+#if 1
+        vector_type<bhalf_t, 8> result;
+
+        result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x));
+        result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
+
+        y = result.template AsType<bhalf8_t>()[Number<0>{}];
+#else
+        vector_type<bhalf_t, 8> dst;
+        vector_type<pk_i4_t, 4> src{x};
+
+        dst.template AsType<bhalf2_t>()(Number<0>{}) =
+            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<0>{}]);
+        dst.template AsType<bhalf2_t>()(Number<1>{}) =
+            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<1>{}]);
+        dst.template AsType<bhalf2_t>()(Number<2>{}) =
+            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<2>{}]);
+        dst.template AsType<bhalf2_t>()(Number<3>{}) =
+            pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<3>{}]);
+
+        y = dst.template AsType<bhalf8_t>()[Number<0>{}];
+#endif
+    }
+    constexpr const static bool is_pack8_invocable = true;
+};

 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Wnon-virtual-dtor"
 struct UnaryOpBase

@@ -49,6 +214,24 @@ struct PassThroughPack2
         auto t = type_convert<float2_t>(x);
         y      = type_convert<half2_t>(t);
     }

+    __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
+    {
+#if 1
+        uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
+        uint8_t x_l  = (x_u8 & 0x0f) >> 0;
+        uint8_t x_h  = (x_u8 & 0xf0) >> 4;
+
+        auto l_f16 = ck::type_convert<ck::half_t>(x_l);
+        auto h_f16 = ck::type_convert<ck::half_t>(x_h);
+
+        y = {l_f16, h_f16};
+#else
+        uint32_t t = ck::bit_cast<uint8_t>(x);
+        y          = ck::bit_cast<half2_t>(t);
+#endif
+    }
     constexpr const static bool is_pack2_invocable = true;
 };

@@ -76,6 +259,12 @@ struct PassThrough final : public UnaryOpBase
     template <typename Y, typename X>
     __host__ __device__ void operator()(Y& y, const X& x) const;

+    template <>
+    __host__ __device__ void operator()<pk_i4_t, pk_i4_t>(pk_i4_t& y, const pk_i4_t& x) const
+    {
+        y = x;
+    }
+
     template <>
     __host__ __device__ void operator()<float, double>(float& y, const double& x) const
     {
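The conversion helpers above rely on a floating-point bias trick. As a sanity check, here is a standalone fp32 version of it using the same constants as pki4_to_bhalf4 (0x4B000000 = 2^23 = 8388608, bias 8388616 = 2^23 + 8); the fp16 path in pki4_to_half4 does the same with 0x6400 (1024.0 in half) and -1032 = -(1024 + 8). Placing an unsigned nibble in the low mantissa bits of the power-of-two constant yields exactly constant + q, so subtracting constant + 8 recovers the signed int4 value q - 8 without an integer-to-float conversion instruction.

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    float nibble_to_float(unsigned q) // q in [0, 15]; the stored int4 value is q - 8
    {
        std::uint32_t bits = 0x4B000000u | q; // 2^23 with q in the mantissa LSBs -> 8388608 + q
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f - 8388616.f;                 // 8388616 = 2^23 + 8
    }

    int main()
    {
        for(unsigned q = 0; q < 16; ++q)
            assert(nibble_to_float(q) == static_cast<float>(static_cast<int>(q) - 8));
    }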
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
e19b0897
...
@@ -127,7 +127,9 @@ template <typename ALayout,
...
@@ -127,7 +127,9 @@ template <typename ALayout,
BlockGemmPipelineScheduler
BlkGemmPipeSched
=
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
=
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
BlkGemmPipelineVer
=
BlockGemmPipelineVersion
::
v4
,
BlockGemmPipelineVersion
BlkGemmPipelineVer
=
BlockGemmPipelineVersion
::
v4
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
typename
ComputeTypeB
=
ComputeTypeA
,
bool
PermuteA
=
false
,
bool
PermuteB
=
false
>
struct
GridwiseGemm_xdl_cshuffle_v3
struct
GridwiseGemm_xdl_cshuffle_v3
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -151,6 +153,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -151,6 +153,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
static
constexpr
index_t
APackedSize
=
[]()
{
if
constexpr
(
is_same_v
<
remove_cvref_t
<
ADataType
>
,
pk_i4_t
>
)
return
2
;
else
return
1
;
}();
static
constexpr
index_t
BPackedSize
=
[]()
{
if
constexpr
(
is_same_v
<
remove_cvref_t
<
BDataType
>
,
pk_i4_t
>
)
return
2
;
else
return
1
;
}();
__host__
static
auto
CalculateGridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
)
__host__
static
auto
CalculateGridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
)
{
{
return
std
::
make_tuple
(
Block2CTileMap
::
CalculateGridSize
(
M
,
N
),
1
,
KBatch
);
return
std
::
make_tuple
(
Block2CTileMap
::
CalculateGridSize
(
M
,
N
),
1
,
KBatch
);
...
@@ -319,6 +335,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -319,6 +335,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
using
GemmSpecialization
=
tensor_operation
::
device
::
GemmSpecialization
;
using
GemmSpecialization
=
tensor_operation
::
device
::
GemmSpecialization
;
static_assert
(
!
(
is_same_v
<
remove_cvref_t
<
ADataType
>
,
pk_i4_t
>
&&
GemmSpec
!=
GemmSpecialization
::
Default
),
"pk_i4_t does not support padding"
);
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
{
...
@@ -373,15 +393,39 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -373,15 +393,39 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
}
else
else
{
{
// not pad N or K
if
constexpr
(
!
PermuteB
)
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
{
b_grid_desc_nraw_kraw
,
// not pad N or K
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1Value
)),
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
make_pass_through_transform
(
N
)),
b_grid_desc_nraw_kraw
,
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1Value
)),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
return
b_grid_desc_bk0_n_bk1
;
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// Pre-shuffled Weight
// BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
constexpr
index_t
BK01
=
KPerBlock
/
BK1Value
;
const
index_t
BK0_
=
StrideB
/
BK1Value
;
const
index_t
BK00
=
BK0_
/
BK01
;
const
auto
b_grid_desc_bk00_n_bk01_bk1_permute
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
BK00
,
N
,
BK01
,
BK1Value
));
const
auto
b_grid_desc_bk0_n_bk1_permute
=
transform_tensor_descriptor
(
b_grid_desc_bk00_n_bk01_bk1_permute
,
make_tuple
(
make_merge_transform
(
make_tuple
(
BK00
,
BK01
)),
make_pass_through_transform
(
make_tuple
(
N
)),
make_pass_through_transform
(
BK1Value
)),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
b_grid_desc_bk0_n_bk1_permute
;
}
}
}
}
}
...
@@ -572,7 +616,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -572,7 +616,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
{
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
{
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
;
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
/
APackedSize
;
}
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
{
...
@@ -585,7 +629,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -585,7 +629,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
{
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
;
if
constexpr
(
!
PermuteB
)
{
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
/
BPackedSize
;
}
else
{
const
int
k0_offset
=
karg
.
KRead
*
karg
.
N
;
b_k_split_offset
=
blockIdx
.
z
*
k0_offset
/
BPackedSize
;
}
}
}
if
(
blockIdx
.
z
<
static_cast
<
uint32_t
>
(
karg
.
KBatch
-
1
))
if
(
blockIdx
.
z
<
static_cast
<
uint32_t
>
(
karg
.
KBatch
-
1
))
...
@@ -625,9 +677,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
         // in some cases.
         else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
         {
-            constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
-                                           ? 1
-                                           : 32 * 4 / KPerBlock / sizeof(ADataType);
+            constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
+            constexpr auto MLdsLayer  = LdsSize < 1 ? 1 : LdsSize;
             constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
                 make_tuple(AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
...
@@ -761,10 +812,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
         else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
         {
             // NLdsLayer * K0 as logical Bank
-            constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1
-                                           ? 1
-                                           : 32 * 4 / KPerBlock / sizeof(BDataType);
-            ;
+            constexpr index_t LdsSize   = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
+            constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
             constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
                 make_tuple(BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
...
@@ -946,8 +995,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
         constexpr auto c_block_size =
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

-        return math::max((a_block_space_size_aligned * sizeof(ADataType) +
-                          b_block_space_size_aligned * sizeof(BDataType)),
+        return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
+                          b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
                          c_block_size * sizeof(CShuffleDataType));
     }
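The shared-memory budget above is the maximum of the A+B tile stage and the C-shuffle stage, with the A and B element counts converted to bytes through their packed sizes. A rough host-side sketch of that calculation, using hypothetical parameters rather than the kernel's actual tuning values:

#include <algorithm>
#include <cstddef>

std::size_t lds_bytes(std::size_t a_elems, std::size_t a_elem_bytes, std::size_t a_packed,
                      std::size_t b_elems, std::size_t b_elem_bytes, std::size_t b_packed,
                      std::size_t c_elems, std::size_t c_elem_bytes)
{
    // A and B tiles live in LDS at the same time; the C shuffle tile reuses the space.
    const std::size_t ab_stage = a_elems * a_elem_bytes / a_packed +
                                 b_elems * b_elem_bytes / b_packed;
    const std::size_t c_stage  = c_elems * c_elem_bytes;
    return std::max(ab_stage, c_stage);
}

For fp16 A (packed size 1) and pk_i4 B (packed size 2), the B term is half of its logical element count in bytes.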
...
@@ -1312,8 +1361,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
             static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<BDataType*>(p_shared) +
-                a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
+            reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
+                                         a_block_space_size_aligned * sizeof(ADataType) /
+                                             APackedSize),
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
@@ -1706,16 +1756,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1706,16 +1756,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
static_cast
<
ADataType
*>
(
p_shared_0
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
static_cast
<
ADataType
*>
(
p_shared_0
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf_ping
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
b_block_buf_ping
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static
_cast
<
BDataType
*>
(
p_shared_0
)
+
bit
_cast
<
BDataType
*>
(
static_cast
<
char
*>
(
p_shared_0
)
+
a_block_space_size_aligned
*
sizeof
(
ADataType
)
/
sizeof
(
BDataType
),
a_block_space_size_aligned
*
sizeof
(
ADataType
)),
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
a_block_buf_pong
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
a_block_buf_pong
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ADataType
*>
(
p_shared_1
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
static_cast
<
ADataType
*>
(
p_shared_1
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf_pong
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
b_block_buf_pong
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static
_cast
<
BDataType
*>
(
p_shared_1
)
+
bit
_cast
<
BDataType
*>
(
bit_cast
<
char
*>
(
p_shared_1
)
+
a_block_space_size_aligned
*
sizeof
(
ADataType
)
/
sizeof
(
BDataType
),
a_block_space_size_aligned
*
sizeof
(
ADataType
)),
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
a_block_bufs
=
make_tuple
(
a_block_buf_ping
,
a_block_buf_pong
);
auto
a_block_bufs
=
make_tuple
(
a_block_buf_ping
,
a_block_buf_pong
);
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @ e19b0897
...
@@ -1007,6 +1007,13 @@ struct ThreadwiseTensorSliceTransfer_v4
     using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));

+    static constexpr index_t PackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
     __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
         : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
     {
...
@@ -1015,6 +1022,11 @@ struct ThreadwiseTensorSliceTransfer_v4
         static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
                       "wrong! Not divisible");
+
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+        {
+            static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
+        }
     }

     template <typename SrcRefToOriginDisplacement,
...
@@ -1109,7 +1121,7 @@ struct ThreadwiseTensorSliceTransfer_v4
             move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);

-            vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
+            vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;

             using src_vector_t = typename decltype(src_tmp_vector)::type;
...
@@ -1120,7 +1132,8 @@ struct ThreadwiseTensorSliceTransfer_v4
             if constexpr(SrcBuffer::IsDynamicBuffer())
             {
                 src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                    src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
+                    src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
+                                                       is_src_valid);
             }
             else if constexpr(SrcBuffer::IsStaticBuffer())
             {
...
@@ -1133,9 +1146,36 @@ struct ThreadwiseTensorSliceTransfer_v4
                 });
             }

-            if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
-                         is_same<remove_cvref_t<DstData>, half_t>::value &&
-                         SrcScalarPerVector % 2 == 0)
+            if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
+            {
+                // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
+                // DstData)
+                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
+
+                constexpr index_t pack_size = 8;
+
+                static_assert(SrcScalarPerVector % pack_size == 0, "");
+
+                using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
+                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
+
+                static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
+                    ck::tensor_operation::element_wise::PassThroughPack8{}(
+                        dst_tmp_vector.template AsType<dst_v_t>()(i),
+                        src_tmp_vector.template AsType<src_v_t>()[i]);
+                });
+
+                // copy data from dst_tmp_vector into dst_buf
+                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                        dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
+
+                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
+                });
+            }
+            else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
+                              is_same<remove_cvref_t<DstData>, half_t>::value &&
+                              SrcScalarPerVector % 2 == 0)
             {
                 // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
                 // DstData)
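The new pk_i4_t branch above widens eight packed int4 values at a time through PassThroughPack8; a vector of SrcScalarPerVector logical elements occupies only SrcScalarPerVector / PackedSize storage bytes on the source side. A minimal host-side sketch of the underlying nibble unpacking (the nibble order and helper name are assumptions for illustration, not the CK implementation):

#include <array>
#include <cassert>
#include <cstdint>

// Unpack one pk_i4-style byte into two signed 4-bit values.
std::array<int, 2> unpack_pk_i4(std::uint8_t byte)
{
    auto sign_extend4 = [](int v) { return v >= 8 ? v - 16 : v; };
    return {sign_extend4(byte & 0x0F), // assumed: first logical element in the low nibble
            sign_extend4(byte >> 4)};
}

int main()
{
    // 0xF3 packs 3 in the low nibble and -1 in the high nibble.
    const auto v = unpack_pk_i4(0xF3);
    assert(v[0] == 3 && v[1] == -1);
    return 0;
}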
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @ e19b0897
...
@@ -31,8 +31,8 @@ template <typename SliceLengths,
           typename DstDimAccessOrder,
           index_t SrcVectorDim,
           index_t DstVectorDim,
-          index_t SrcScalarPerVector,
-          index_t DstScalarPerVector,
+          index_t SrcScalarPerVector_,
+          index_t DstScalarPerVector_,
           index_t SrcScalarStrideInVector,
           index_t DstScalarStrideInVector,
           bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
...
@@ -55,6 +55,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     static constexpr auto I0 = Number<0>{};

+    static constexpr index_t PackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
+    static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
+    static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
+
     __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
         const SrcDesc& src_desc,
         const Index& src_slice_origin,
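PackedSize above turns the user-facing SrcScalarPerVector_/DstScalarPerVector_ into per-storage-unit vector widths. A compile-time sketch of the same rule in isolation (pk_i4_t here is a stand-in struct, not the ck type):

#include <cstdint>
#include <type_traits>

struct pk_i4_t { std::uint8_t data; }; // stand-in for ck's packed int4 type

template <typename T>
inline constexpr int packed_size_v =
    std::is_same_v<std::remove_cv_t<T>, pk_i4_t> ? 2 : 1;

// A vector of N logical elements spans N / packed_size_v<T> storage units.
static_assert(packed_size_v<pk_i4_t> == 2, "two int4 per storage unit");
static_assert(packed_size_v<float> == 1, "plain types are unpacked");
static_assert(8 / packed_size_v<pk_i4_t> == 4, "8 logical int4 -> 4 bytes");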
...
@@ -67,6 +77,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
           src_element_op_(src_element_op),
           dst_element_op_(dst_element_op)
     {
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+        {
+            static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
+                          "SrcData != DstData");
+
+            static_assert(SrcScalarPerVector_ % PackedSize == 0 &&
+                              DstScalarPerVector_ % PackedSize == 0,
+                          "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
+
+            static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
+        }
     }

     __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
...
@@ -95,11 +116,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});

         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;

-        static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
+        static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0,
                       "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");

         constexpr auto src_dim_access_order = SrcDimAccessOrder{};
...
@@ -180,9 +201,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
         using src_vector_t    = typename src_vector_type::type;

-        auto src_vector_container = src_vector_type{
-            src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), true)};
-
         using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
         using dst_vector_t    = typename dst_vector_type::type;

         dst_vector_type op_r_v;
...
@@ -193,17 +211,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 if constexpr(decltype(src_element_op_)::is_pack8_invocable)
                     return math::min(8, SrcScalarPerVector);
             }
-            if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
+            else if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
             {
                 if constexpr(decltype(src_element_op_)::is_pack4_invocable)
                     return math::min(4, SrcScalarPerVector);
             }
-            if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
+            else if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
             {
                 if constexpr(decltype(src_element_op_)::is_pack2_invocable)
                     return math::min(2, SrcScalarPerVector);
             }
-            return 1;
+            else
+            {
+                return 1;
+            }
         };

         constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
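The change above chains the detection branches with else-if so that the fallback return 1 is reached only when none of the packN overloads is detected. A simplified sketch of the corrected dispatch, assuming the is_packN_invocable flags always exist (the real code first checks for them with the detection idiom):

#include <algorithm>

template <typename Op, int SrcScalarPerVector>
constexpr int elem_op_vec_len()
{
    if constexpr(Op::is_pack8_invocable)
        return std::min(8, SrcScalarPerVector);
    else if constexpr(Op::is_pack4_invocable)
        return std::min(4, SrcScalarPerVector);
    else if constexpr(Op::is_pack2_invocable)
        return std::min(2, SrcScalarPerVector);
    else
        return 1;
}

struct Pack2OnlyOp
{
    static constexpr bool is_pack8_invocable = false;
    static constexpr bool is_pack4_invocable = false;
    static constexpr bool is_pack2_invocable = true;
};

static_assert(elem_op_vec_len<Pack2OnlyOp, 8>() == 2, "widest supported pack wins");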
...
@@ -211,6 +234,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
         using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;

+        auto src_vector_container = src_vector_type{
+            src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
+
         static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
             // apply the src elementwise op and convert to DstData under the hood if needed
             src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
...
@@ -276,10 +302,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
         });
#else
         // OOB Check
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});

         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
...
@@ -350,6 +375,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
               (is_same<f8_t, remove_cvref_t<DstData>>::value &&
                SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
         {
+            static_assert(!is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
+                          "in-register transpose is not supported for pk_i4_t");
             // each transpose does
             // DstScalarPerVector # of src vectors in src_thread_scratch_
             // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
...
@@ -410,7 +437,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         }
         else
         {
-            static_ford<SliceLengths>{}([&](auto idx) {
+            constexpr auto packed_per_access = generate_sequence(
+                detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
+
+            constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
+
+            static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
                 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
             });
         }
...
@@ -438,7 +470,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // src scalar per access on each dim
         // TODO: don't use this
         constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});

         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
...
@@ -526,13 +558,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             // apply DstElementwiseOperation
             dst_element_op_(dst_v, dst_vector_container.template AsType<DstData>()[i]);
-
-            dst_vector_container.template AsType<DstData>()(i) = dst_v;
         });

         // copy data from dst_vector_container to dst_buf
         dst_buf.template Set<dst_vector_t>(
-            dst_coord_.GetOffset(),
+            dst_coord_.GetOffset() / PackedSize,
             is_dst_valid,
             dst_vector_container.template AsType<dst_vector_t>()[I0]);
...
@@ -586,7 +616,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});

         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
...
@@ -644,7 +674,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});

         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
...
@@ -730,7 +760,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     __device__ static constexpr auto GetSrcThreadScratchDescriptor()
     {
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});

         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
...
@@ -779,7 +809,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
     {
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});

         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
...
@@ -790,7 +820,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     {
         // 1st stage of transforms
         constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});

         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
...