gaoqiong / composable_kernel · Commits · 66e61076

Commit 66e61076, authored Jul 28, 2023 by aska-0096

    Sanity pass.

parent 0c51a35e
Showing 6 changed files with 359 additions and 35 deletions.
example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp                              +1   -1
example/49_fpAintB_gemm/run_gemm_example.inc                                +46  -2
include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp       +87  -17
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp    +49  -4
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp         +65  -10
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp          +111 -1
example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp

@@ -9,7 +9,7 @@ using ADataType = ck::half_t;
 using BDataType        = int8_t;
 using ScaleDataType    = ck::half_t;
 using AccDataType      = float;
-using CShuffleDataType = float;
+using CShuffleDataType = ck::half_t;
 using CDataType        = ck::half_t;
 using ALayout          = Row;
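The one-line change above narrows the C-shuffle staging type from fp32 to the fp16 output type. A minimal host-side sketch of what that implies, under the assumption that the shuffle stage is essentially a staging buffer between the fp32 accumulators and the fp16 output (all names are illustrative, not CK's; `_Float16` is a GCC/Clang extension standing in for `ck::half_t`):

```cpp
#include <cstdio>
#include <vector>

using half_t = _Float16; // stand-in for ck::half_t (compiler extension, assumption)

// Accumulate in float, stage in ShuffleT, emit half_t. With ShuffleT = half_t
// the narrowing conversion happens once, on the staging write, and the staging
// storage uses half the bytes, which matters when the real buffer lives in LDS.
template <typename ShuffleT>
std::vector<half_t> store_through_shuffle(const std::vector<float>& acc)
{
    std::vector<ShuffleT> staging(acc.size());
    for(std::size_t i = 0; i < acc.size(); ++i)
        staging[i] = static_cast<ShuffleT>(acc[i]);

    std::vector<half_t> out(acc.size());
    for(std::size_t i = 0; i < acc.size(); ++i)
        out[i] = static_cast<half_t>(staging[i]); // no-op when ShuffleT == half_t
    return out;
}

int main()
{
    const std::vector<float> acc = {1.5f, -2.25f, 1024.0f};
    const auto out = store_through_shuffle<half_t>(acc);
    std::printf("staging element: %zu bytes instead of %zu\n", sizeof(half_t), sizeof(float));
    std::printf("out[2] = %f\n", static_cast<float>(out[2]));
    return 0;
}
```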
example/49_fpAintB_gemm/run_gemm_example.inc

@@ -28,7 +28,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
     // assume scale tensor is [1, n]
-    Tensor<ScaleDataType> scale_k_n(f_host_tensor_descriptor(K, N, 0, BLayout{}));
+    Tensor<ScaleDataType> scale_k_n(f_host_tensor_descriptor(K, N, 0, Row{}));

     switch(config.init_method)
     {
@@ -51,7 +51,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     case 4:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
-        ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{1.f, 1.f}(scale_k_n);
+        ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{2.f, 2.f}(scale_k_n);
         break;
     case 5:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k);
@@ -64,6 +64,50 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         ck::utils::FillUniformDistribution<ScaleDataType>{-1.f, 1.f}(scale_k_n);
     }

+#if 0
+    printf("Matrix A:\n");
+    for(int im = 0; im < M; im++)
+    {
+        for(int ik = 0; ik < K; ik++)
+        {
+            if(ik % 16 == 0){ printf("|"); }
+            printf(" %04x", *(reinterpret_cast<uint16_t*>(&a_m_k(im, ik))));
+        }
+        printf("\n");
+    }
+    printf("Matrix B:\n");
+    for(int in = 0; in < N; in++)
+    {
+        for(int ik = 0; ik < K; ik++)
+        {
+            if(ik % 16 == 0){ printf("|"); }
+            printf(" %02x", b_k_n(ik, in));
+        }
+        printf("\n");
+    }
+    printf("Matrix Scale:\n");
+    for(int in = 0; in < N; in++)
+    {
+        for(int ik = 0; ik < K; ik++)
+        {
+            if(ik % 16 == 0){ printf("|"); }
+            printf(" %04x", *(reinterpret_cast<uint16_t*>(&scale_k_n(ik, in))));
+        }
+        printf("\n");
+    }
+#endif
+
     Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
     Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
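Most of this file's additions are the `#if 0` hex-dump helpers above. They rely on reading a 16-bit float's raw bit pattern as an integer. A self-contained sketch of the same idiom, using `std::memcpy` (the strictly well-defined spelling of the block's `reinterpret_cast`) and `_Float16` as a stand-in for `ck::half_t`:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main()
{
    _Float16 x = 1.0f; // fp16 1.0 is encoded as 0x3c00
    uint16_t bits;
    std::memcpy(&bits, &x, sizeof(bits)); // reinterpret the 16-bit storage
    std::printf("%04x\n", bits);          // prints "3c00"
    return 0;
}
```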
include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp

@@ -309,7 +309,8 @@ struct Blockwise_fpAintB_GemmWMMA
             b_thread_desc_.GetElementSpaceSize());
         auto scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ScaleDataType>(
             scale_thread_desc_.GetElementSpaceSize());
-        auto converted_b_thread_buf = b_thread_buf;
+        auto converted_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
+            b_thread_desc_.GetElementSpaceSize());

         // basic intrinsic to determine loopover direction
         if constexpr(MRepeat < NRepeat)
@@ -345,7 +346,7 @@ struct Blockwise_fpAintB_GemmWMMA
                         scale_thread_buf);

                     // convert B from int8 to fp16, multiply scale
-                    static_for<0, b_thread_buf.size(), 1>{}([&](auto i) {
+                    static_for<0, b_thread_buf.Size(), 1>{}([&](auto i) {
                         converted_b_thread_buf(i) =
                             scale_thread_buf[i / WmmaK] *
                             type_convert<ADataType>(b_thread_buf[i]);
@@ -390,6 +391,20 @@ struct Blockwise_fpAintB_GemmWMMA
         else
         {
             static_for<0, NRepeat, 1>{}([&](auto n0) {
+                // read weight scale
+                scale_thread_copy_.Run(scale_block_desc_1_n0_n1_n2_1,
+                                       make_tuple(I0, n0, I0, I0, I0, I0),
+                                       scale_block_buf,
+                                       scale_thread_desc_,
+                                       make_tuple(I0, n0, I0, I0, I0, I0),
+                                       scale_thread_buf);
+#if 0
+                printf("Tid: %03d, n: %02d, scale_thread_buf: %04x\n",
+                       get_thread_local_1d_id(), n0.value,
+                       *(reinterpret_cast<const uint16_t*>(&scale_thread_buf[n0])));
+#endif
                 static_for<0, MRepeat, 1>{}([&](auto m0) {
                     static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of
                                                                         // k=0,kpack*1, ..
@@ -401,15 +416,6 @@ struct Blockwise_fpAintB_GemmWMMA
                             b_thread_desc_,
                             make_tuple(I0, n0, I0, I0, I0, I0),
                             b_thread_buf);
-                        // read weight scale
-                        scale_thread_copy_.Run(
-                            scale_block_desc_1_n0_n1_n2_1,
-                            make_tuple(
-                                Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
-                            scale_block_buf,
-                            scale_thread_desc_,
-                            make_tuple(I0, n0, I0, I0, I0, I0),
-                            scale_thread_buf);

                         // convert B from int8 to fp16, multiply scale
                         static_for<0, b_thread_buf.Size(), 1>{}([&](auto i) {
                             converted_b_thread_buf(i) = scale_thread_buf[i / WmmaK] *
@@ -423,7 +429,71 @@ struct Blockwise_fpAintB_GemmWMMA
                             a_thread_desc_,
                             make_tuple(I0, m0, I0, I0, I0, I0),
                             a_thread_buf);
+                        if(true){
+#if 0
+                        printf("Tid: %03d, m, n, k: %02d, %02d, %02d, a_thread_buf: %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x|\n",
+                        get_thread_local_1d_id(), m0.value, n0.value, k.value,
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<0>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<1>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<2>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<3>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<4>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<5>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<6>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<7>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<8>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<9>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<10>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<11>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<12>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<13>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<14>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<15>{}]))
+                        );
+#endif
+#if 0
+                        printf("Tid: %03d, m, n, k: %02d, %02d, %02d, b_thread_buf: %02x %02x %02x %02x| %02x %02x %02x %02x| %02x %02x %02x %02x| %02x %02x %02x %02x|\n",
+                        get_thread_local_1d_id(), m0.value, n0.value, k.value,
+                        b_thread_buf[Number<0>{}],
+                        b_thread_buf[Number<1>{}],
+                        b_thread_buf[Number<2>{}],
+                        b_thread_buf[Number<3>{}],
+                        b_thread_buf[Number<4>{}],
+                        b_thread_buf[Number<5>{}],
+                        b_thread_buf[Number<6>{}],
+                        b_thread_buf[Number<7>{}],
+                        b_thread_buf[Number<8>{}],
+                        b_thread_buf[Number<9>{}],
+                        b_thread_buf[Number<10>{}],
+                        b_thread_buf[Number<11>{}],
+                        b_thread_buf[Number<12>{}],
+                        b_thread_buf[Number<13>{}],
+                        b_thread_buf[Number<14>{}],
+                        b_thread_buf[Number<15>{}]
+                        );
+#endif
+#if 0
+                        printf("Tid: %03d, m, n, k: %02d, %02d, %02d, converted_b_thread_buf: %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x|\n",
+                        get_thread_local_1d_id(), m0.value, n0.value, k.value,
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<0>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<1>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<2>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<3>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<4>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<5>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<6>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<7>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<8>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<9>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<10>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<11>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<12>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<13>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<14>{}])),
+                        *(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<15>{}]))
+                        );
+#endif
+                        }
                         vector_type<ADataType, WmmaK> a_thread_vec;
                         vector_type<ADataType, WmmaK> b_thread_vec;
@@ -497,7 +567,7 @@ struct Blockwise_fpAintB_GemmWMMA
                        I1,
                        Number<B_KRow>{},
                        I1,
-                       Number<B_K1>{}),
+                       I1),
             make_tuple(I0, I1, I0, I0, I0, I0));

         // C[M, N, NumRegWMMA]
@@ -587,11 +657,11 @@ struct Blockwise_fpAintB_GemmWMMA
                                          ScaleDataType,
                                          decltype(scale_block_desc_1_n0_n1_n2_1),
                                          decltype(scale_thread_desc_),
-                                         Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
+                                         Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, 1>,
                                          Sequence<0, 1, 2, 3, 4, 5>,
                                          5,
-                                         B_K1,
-                                         B_K1>;
+                                         1,
+                                         1>;
 };

 template <>
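The functional fix in this file: `converted_b_thread_buf` used to be declared as a copy of the int8 `b_thread_buf`, so the scaled fp16 products were written back into an int8-typed buffer; it is now a separate `ADataType` (fp16) buffer. A host-side sketch of the dequantization the loop performs, `converted_b[i] = scale[i / WmmaK] * convert(b[i])`, with invented sizes and plain `float` standing in for fp16:

```cpp
#include <cstdint>
#include <cstdio>

constexpr int WmmaK = 16; // one scale value covers one WmmaK-wide slice

// Per-slice dequantization: names are illustrative, not CK's.
void dequant(const int8_t* b, const float* scale, float* out, int n)
{
    for(int i = 0; i < n; ++i)
        out[i] = scale[i / WmmaK] * static_cast<float>(b[i]);
}

int main()
{
    int8_t b[32];
    float  scale[2] = {0.5f, 2.0f};
    float  out[32];
    for(int i = 0; i < 32; ++i)
        b[i] = static_cast<int8_t>(i - 16);

    dequant(b, scale, out, 32);
    std::printf("out[0] = %g, out[31] = %g\n", out[0], out[31]); // -8 and 30
    return 0;
}
```

Had `out` been int8-typed, each product would be rounded back to an integer on store, which appears to be exactly what the original aliasing caused.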
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp

@@ -182,8 +182,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
         const auto N = b_grid_desc_n_k.GetLength(I0);
         const auto K = b_grid_desc_n_k.GetLength(I1);

-        // When K = 1, it might be scale tensor.
-        assert(K % K1 == 0);
+        assert(K % K1 == 0 && K != 1);

         if constexpr(BEnableLds)
         {
@@ -216,6 +215,52 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
             }
         }
     }

+    static auto MakeScaleGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB = 0)
+    {
+        // assume Scale is [1, N]
+        const auto scale_grid_desc_n_k = [&]() {
+            const auto scale_grid_desc_nraw_kraw =
+                make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
+
+            return matrix_padder.PadBDescriptor_N_K(scale_grid_desc_nraw_kraw);
+        }();
+
+        const auto N = scale_grid_desc_n_k.GetLength(I0);
+        const auto K = scale_grid_desc_n_k.GetLength(I1);
+
+        // When K = 1, it might be scale tensor.
+        assert(K % K1 == 0 && K != 1);
+
+        if constexpr(BEnableLds)
+        {
+            const index_t K0 = K / K1;
+
+            return transform_tensor_descriptor(
+                scale_grid_desc_n_k,
+                make_tuple(make_unmerge_transform(make_tuple(K0, 1)), // Reduce K1 = 1
+                           make_pass_through_transform(N)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+        }
+        else
+        {
+            constexpr auto B_KRow      = 2;
+            constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
+            const auto B_KWmma         = K / WmmaK;
+            const auto N0              = N / NPerBlock;
+
+            //  0   1      0          1           2          3          4         5        6
+            //  M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
+            return transform_tensor_descriptor(
+                scale_grid_desc_n_k,
+                make_tuple(make_unmerge_transform(make_tuple(
+                               B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
+                           make_unmerge_transform(
+                               make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
+        }
+    }
+
     static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
     {
         const auto c_grid_desc_mraw_nraw = [&]() {
@@ -237,7 +282,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
     // Gridwise descriptor, mapping to whole given provblem.
     using AGridDesc     = decltype(MakeAGridDescriptor(1, 1, 1));
     using BGridDesc     = decltype(MakeBGridDescriptor(1, 1, 1));
-    using ScaleGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
+    using ScaleGridDesc = decltype(MakeScaleGridDescriptor(1, 1, 0));
     using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

     // GridwiseGemm
@@ -330,7 +375,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
         {
             a_grid_desc_     = DeviceOp::MakeAGridDescriptor(M, K, StrideA);
             b_grid_desc_     = DeviceOp::MakeBGridDescriptor(K, N, StrideB);
-            scale_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, 0);
+            scale_grid_desc_ = DeviceOp::MakeScaleGridDescriptor(K, N, 0);
             c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC);
             block_2_ctile_map_ =
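The new `MakeScaleGridDescriptor` views the `[1, N]` scale vector as a full `[K, N]` operand by giving the K axis a stride of 0 (hence the `StrideB = 0` default), so every k coordinate aliases the same scale row. A tiny host sketch of that zero-stride broadcast; `View2D` is an invented helper, not a CK type:

```cpp
#include <cstdio>
#include <vector>

struct View2D
{
    const float* data;
    long stride_k, stride_n;
    float at(long k, long n) const { return data[k * stride_k + n * stride_n]; }
};

int main()
{
    const std::vector<float> scale = {0.5f, 1.0f, 2.0f}; // [1, N] with N = 3

    // stride_k = 0: rows k = 0..K-1 all read the same underlying row.
    const View2D broadcast{scale.data(), 0, 1};
    for(long k = 0; k < 4; ++k)
        std::printf("k=%ld: %g %g %g\n",
                    k, broadcast.at(k, 0), broadcast.at(k, 1), broadcast.at(k, 2));
    return 0;
}
```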
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp

@@ -52,6 +52,12 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
     defined(__gfx1102__))
     __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];

+    if(false && get_thread_local_1d_id() == 0){
+        printf("lds_size: %lu\n", GridwiseGemm::SharedMemTrait::lds_size);
+        printf("lds_a_size: %d\n", GridwiseGemm::SharedMemTrait::a_block_space_size_aligned);
+        printf("lds_b_size: %d\n", GridwiseGemm::SharedMemTrait::b_block_space_size_aligned);
+        printf("lds_scale_size: %d\n", GridwiseGemm::SharedMemTrait::scale_block_space_size_aligned);
+    }
+
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
@@ -262,7 +268,7 @@ struct GridwiseFpAintBGemm_Wmma
             constexpr auto K0PerBlock = KPerBlock / K1;

             return make_naive_tensor_descriptor(
-                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
+                make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, I1),
                 make_tuple(I0, I1, I0));
         }
         else
@@ -276,7 +282,7 @@ struct GridwiseFpAintBGemm_Wmma
                            Number<K0PerWmma>{},
                            I1,
                            I1,
-                           K1),
+                           I1),
                 make_tuple(I0, I1, I0, I0, I0, I0, I0));
         }
     }();
@@ -424,6 +430,52 @@ struct GridwiseFpAintBGemm_Wmma
         return b_wave_desc;
     }

+    template <typename ScaleBlockDesc_>
+    __host__ __device__ static constexpr auto MakeScaleWaveDescriptor(const ScaleBlockDesc_&)
+    {
+        constexpr auto scale_wave_desc = [&]() {
+            if constexpr(BEnableLds)
+            {
+                // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
+                constexpr auto B_K0   = ScaleBlockDesc_{}.GetLength(I0);
+                constexpr auto B_K1   = ScaleBlockDesc_{}.GetLength(I2);
+                constexpr auto B_KRow = I1;
+                return transform_tensor_descriptor(
+                    ScaleBlockDesc_{},
+                    make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+                               make_unmerge_transform(make_tuple(
+                                   Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
+                               make_pass_through_transform(Number<B_K1>{})),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                    make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
+            }
+            else
+            {
+                // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
+                constexpr auto KWmma     = ScaleBlockDesc_{}.GetLength(I0);
+                constexpr auto K0PerWmma = ScaleBlockDesc_{}.GetLength(I3);
+                constexpr auto B_KRow    = ScaleBlockDesc_{}.GetLength(I4);
+                constexpr auto B_K1      = ScaleBlockDesc_{}.GetLength(I6);
+
+                // Workaround, Freeze transform
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KWmma * K0PerWmma>{},
+                               Number<NRepeat>{},
+                               I1,
+                               Number<B_KRow>{},
+                               I1,
+                               Number<B_K1>{}),
+                    make_tuple(I0, I1, I0, I0, I0, I0));
+            }
+        }();
+
+        return scale_wave_desc;
+    }
+
     __host__ __device__ static constexpr auto
     // *Caution Here repeat is shuffle repeat
     GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
@@ -590,9 +642,10 @@ struct GridwiseFpAintBGemm_Wmma
                   : 0;

         static constexpr auto a_block_space_offset = 0;
-        static constexpr auto b_block_space_offset = a_block_space_size_aligned;
-        static constexpr auto scale_block_space_offset =
-            b_block_space_offset + b_block_space_size_aligned;
+        static constexpr auto b_block_space_offset =
+            (a_block_space_offset + a_block_space_size_aligned) * sizeof(ADataType) / sizeof(BDataType);
+        static constexpr auto scale_block_space_offset =
+            (b_block_space_offset + b_block_space_size_aligned) * sizeof(BDataType) / sizeof(ScaleDataType);

         // LDS allocation for C shuffle in LDS
         static constexpr auto c_shuffle_block_space_size =
@@ -753,7 +806,7 @@ struct GridwiseFpAintBGemm_Wmma
         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<BDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
             SharedMemTrait::b_block_space_size_aligned);
+        // printf("b_lds_offset: %lu\n", SharedMemTrait::b_block_space_offset);

         auto b_blockwise_copy =
             ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
@@ -834,13 +887,15 @@ struct GridwiseFpAintBGemm_Wmma
         auto scale_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<ScaleDataType*>(p_shared) + SharedMemTrait::scale_block_space_offset,
             SharedMemTrait::scale_block_space_size_aligned);
+        // printf("scale_lds_offset: %lu\n", SharedMemTrait::scale_block_space_offset);

         auto scale_blockwise_copy =
             ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
-                                                Sequence<K0PerBlock, NPerBlock, K1>,
+                                                // Reduce slice length K1 to 1
+                                                Sequence<K0PerBlock, NPerBlock, I1>,
                                                 BBlockTransferThreadClusterLengths_K0_N_K1,
                                                 BBlockTransferThreadClusterArrangeOrder,
                                                 ScaleDataType,
@@ -851,10 +906,10 @@ struct GridwiseFpAintBGemm_Wmma
                                                 Sequence<0, 1, 2>,
                                                 BBlockTransferSrcVectorDim,
                                                 2,
-                                                BBlockTransferSrcScalarPerVector,
-                                                BBlockTransferDstScalarPerVector_K1,
+                                                1,
+                                                1,
                                                 1,
                                                 1, // no effect
                                                 BThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
                                                 NumGemmKPrefetchStage>(
@@ -926,7 +981,7 @@ struct GridwiseFpAintBGemm_Wmma
             AccDataType,
             decltype(MakeAWaveDescriptor(a_block_desc)),
             decltype(MakeBWaveDescriptor(b_block_desc)),
-            decltype(MakeBWaveDescriptor(scale_block_desc)),
+            decltype(MakeScaleWaveDescriptor(scale_block_desc)),
             MPerBlock,
             NPerBlock,
             KPerBlock,
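The `SharedMemTrait` change above rescales each LDS sub-buffer's offset into units of that buffer's own element type; previously an offset counted in fp16 A elements was added directly to an int8 `BDataType*`, placing the B and scale tiles at the wrong bytes. A standalone sketch of the corrected arithmetic, with invented tile sizes:

```cpp
#include <cstdint>
#include <cstdio>

using ADataType     = uint16_t; // stands in for the fp16 A tile
using BDataType     = int8_t;   // int8 B tile
using ScaleDataType = uint16_t; // fp16 scale tile

int main()
{
    constexpr long a_elems = 1024, b_elems = 2048; // invented aligned tile sizes

    // Each offset is counted in elements of its own buffer's type, so counts
    // carried over from the previous buffer are rescaled by the size ratio.
    constexpr long a_off     = 0;
    constexpr long b_off     = (a_off + a_elems) * sizeof(ADataType) / sizeof(BDataType);
    constexpr long scale_off = (b_off + b_elems) * sizeof(BDataType) / sizeof(ScaleDataType);

    // Byte addresses line up no matter which element type indexes them:
    std::printf("B tile starts at byte %zu\n", b_off * sizeof(BDataType));             // 2048
    std::printf("scale tile starts at byte %zu\n", scale_off * sizeof(ScaleDataType)); // 4096
    return 0;
}
```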
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp

@@ -581,9 +581,9 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
               typename BBlockTransferStep,
               typename ScaleGridDesc,
               typename ScaleBlockDesc,
-              typename ScaleBlockTransfer,
               typename ScaleGridBuffer,
               typename ScaleBlockBuffer,
+              typename ScaleBlockTransfer,
               typename BlockwiseGemm,
               typename CThreadBuffer>
     __device__ static void Run(const AGridDesc& a_grid_desc,
@@ -658,6 +658,116 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
     }
 };

+template <>
+struct GridwiseGemmPipeline_v1_dequant<1, true, false>
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename ScaleGridDesc,
+              typename ScaleBlockDesc,
+              typename ScaleBlockTransfer,
+              typename ScaleGridBuffer,
+              typename ScaleBlockBuffer,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
+    __device__ static void Run(const AGridDesc& a_grid_desc,
+                               const ABlockDesc& a_block_desc,
+                               ABlockTransfer& a_blockwise_copy,
+                               const AGridBuffer& a_grid_buf,
+                               ABlockBuffer& a_block_buf,
+                               const ABlockTransferStep& a_block_copy_step,
+                               const BGridDesc& b_grid_desc,
+                               const BBlockDesc& b_block_desc,
+                               BBlockTransfer& b_blockwise_copy,
+                               const BGridBuffer& b_grid_buf,
+                               BBlockBuffer& b_block_buf,
+                               const BBlockTransferStep& b_block_copy_step,
+                               const ScaleGridDesc& scale_grid_desc,
+                               const ScaleBlockDesc& scale_block_desc,
+                               ScaleBlockTransfer& scale_blockwise_copy,
+                               const ScaleGridBuffer& scale_grid_buf,
+                               ScaleBlockBuffer& scale_block_buf,
+                               const BlockwiseGemm& blockwise_gemm,
+                               CThreadBuffer& c_thread_buf,
+                               index_t num_loop)
+    {
+        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
+        auto b_block_buf_switch           = b_block_buf;
+
+        // preload data into LDS
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        b_blockwise_copy.Run(
+            b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
+        scale_blockwise_copy.Run(
+            scale_grid_desc, scale_grid_buf, scale_block_desc, b_block_origin_idx, scale_block_buf);
+
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        // Initialize C
+        c_thread_buf.Clear();
+
+        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+
+        // main body
+        if constexpr(HasMainLoop)
+        {
+            index_t i = 0;
+
+            do
+            {
+                b_blockwise_copy.Run(
+                    b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
+
+                block_sync_lds();
+                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+
+                blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
+
+                block_sync_lds();
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+                b_block_buf = b_block_buf_switch;
+
+                ++i;
+            } while(i < (num_loop - 1));
+        }
+
+        // tail
+        {
+            block_sync_lds();
+            blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
+            block_sync_lds();
+        }
+    }
+};
+
 template <index_t NumPrefetch>
 struct GridwiseGemmPipelineInterwave_v1;
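The new `<1, true, false>` specialization (B bypassing LDS, going by the existing `<1, true, true>` naming) preloads A, B, and the scales once, then double-buffers B: the next tile is fetched into `b_block_buf_switch` while the current one feeds the GEMM, and swapped in at the end of each iteration. A host-side sketch of that loop structure; the names and the trivial `accumulate` are illustrative, not CK's:

```cpp
#include <cstdio>
#include <vector>

int main()
{
    const int num_loop = 4, tile = 8;
    const std::vector<float> a_src(num_loop * tile, 1.0f);
    const std::vector<float> b_src(num_loop * tile, 2.0f);

    std::vector<float> a_buf(tile), b_buf(tile), b_buf_switch(tile);
    auto load = [tile](std::vector<float>& dst, const std::vector<float>& src, int i) {
        for(int j = 0; j < tile; ++j)
            dst[j] = src[i * tile + j];
    };

    float c = 0.0f; // c_thread_buf stand-in
    auto accumulate = [&] {
        for(int j = 0; j < tile; ++j)
            c += a_buf[j] * b_buf[j];
    };

    load(a_buf, a_src, 0); // preload the first tiles
    load(b_buf, b_src, 0);

    for(int i = 0; i < num_loop - 1; ++i) // main body
    {
        load(b_buf_switch, b_src, i + 1); // prefetch next B while current is in use
        accumulate();
        load(a_buf, a_src, i + 1);
        b_buf = b_buf_switch;             // swap in the prefetched tile
    }
    accumulate(); // tail

    std::printf("c = %g (expected %g)\n", c, 2.0 * tile * num_loop);
    return 0;
}
```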