Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8713ade3
Commit
8713ade3
authored
Dec 11, 2024
by
mtgu0705
Browse files
Enable splitK
parent
7a17ead7
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
69 additions
and
46 deletions
+69
-46
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
+12
-0
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
+1
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp
...n/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp
+4
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp
+51
-46
profiler/include/profiler/profile_gemm_b_scale_impl.hpp
profiler/include/profiler/profile_gemm_b_scale_impl.hpp
+1
-0
No files found.
example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp
View file @
8713ade3
...
...
@@ -45,6 +45,17 @@ using DeviceGemmV2Instance =
S
<
2
,
128
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
32
,
32
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
// 128, Scale_Block_N, Scale_Block_K,
// 16, 128,
// KPerBlock, 8, 32,
// 16, 16,
// 1, 4,
// S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 8, 8, 0,
// S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
// 2, 32, 32, 0,
// 1, 1, S<1, 16, 1, 8>, 4,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v3
,
CDataType
,
CDataType
,
false
,
PermuteB
>
;
// clang-format on
...
...
@@ -273,6 +284,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
StrideA
,
StrideB
,
StrideC
,
Scale_Stride_BN
,
static_cast
<
BScaleDataType
*>
(
b1_scale_device_buf
.
GetDeviceBuffer
()),
KBatch
,
a_element_op
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
View file @
8713ade3
...
...
@@ -100,6 +100,7 @@ struct DeviceGemmV2BScale : public BaseOperator
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
ck
::
index_t
StrideScaleB
,
const
void
*
p_b_scale
,
ck
::
index_t
KSplit
,
AElementwiseOperation
a_element_op
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp
View file @
8713ade3
...
...
@@ -663,6 +663,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideScaleB
,
const
BScaleDataType
*
p_b_scale
,
index_t
KBatch
,
AElementwiseOperation
a_element_op
,
...
...
@@ -678,6 +679,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
StrideA
,
StrideB
,
StrideC
,
StrideScaleB
,
p_b_scale
,
KBatch
,
a_element_op
,
...
...
@@ -697,6 +699,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideScaleB
,
const
void
*
p_b_scale
,
index_t
KBatch
,
AElementwiseOperation
a_element_op
,
...
...
@@ -712,6 +715,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
StrideA
,
StrideB
,
StrideC
,
StrideScaleB
,
static_cast
<
const
BScaleDataType
*>
(
p_b_scale
),
KBatch
,
a_element_op
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp
View file @
8713ade3
...
...
@@ -37,18 +37,16 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
// GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
// karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
// karg.p_c_grid + splitk_batch_offset.scale_k_split_offset,
// karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
// p_shared,
// karg);
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
karg
.
p_b_scale_grid
,
p_shared
,
karg
);
karg
.
p_a_grid
+
splitk_batch_offset
.
a_k_split_offset
,
karg
.
p_b_grid
+
splitk_batch_offset
.
b_k_split_offset
,
karg
.
p_c_grid
+
splitk_batch_offset
.
c_reduce_offset
,
karg
.
p_b_scale_grid
+
splitk_batch_offset
.
scale_k_split_offset
,
p_shared
,
karg
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx9__))
...
...
@@ -72,24 +70,17 @@ __global__ void
__shared__
char
p_shared_0
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared_1
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
// GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
// karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
// karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
// karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
// p_shared_0,
// p_shared_1,
// karg);
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
);
GridwiseGemm
::
template
Run_2Lds
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
karg
.
p_b_scale_grid
,
karg
.
p_a_grid
+
splitk_batch_offset
.
a_k_split_offset
,
karg
.
p_b_grid
+
splitk_batch_offset
.
b_k_split_offset
,
karg
.
p_c_grid
+
splitk_batch_offset
.
c_reduce_offset
,
karg
.
p_b_scale_grid
+
splitk_batch_offset
.
scale_k_split_offset
,
p_shared_0
,
p_shared_1
,
karg
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx9__))
...
...
@@ -533,6 +524,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
,
index_t
StrideScaleB_
,
index_t
KBatch_
)
:
M
{
M_
},
N
{
N_
},
...
...
@@ -540,6 +532,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
StrideA
{
StrideA_
},
StrideB
{
StrideB_
},
StrideC
{
StrideC_
},
StrideScaleB
{
StrideScaleB_
},
KBatch
{
KBatch_
},
MPadded
{
CalculateMPadded
(
M_
)},
NPadded
{
CalculateNPadded
(
N_
)},
...
...
@@ -561,6 +554,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
<<
"SA:"
<<
StrideA
<<
", "
<<
"SB:"
<<
StrideB
<<
", "
<<
"SC:"
<<
StrideC
<<
", "
<<
"SScaleB:"
<<
StrideScaleB
<<
", "
<<
"MP:"
<<
MPadded
<<
", "
<<
"NP:"
<<
NPadded
<<
", "
<<
"KRead:"
<<
KRead
<<
", "
...
...
@@ -577,6 +571,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t
StrideA
;
index_t
StrideB
;
index_t
StrideC
;
index_t
StrideScaleB
;
index_t
KBatch
;
index_t
MPadded
;
index_t
NPadded
;
...
...
@@ -600,13 +595,14 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
,
index_t
StrideScaleB_
,
const
BScaleType
*
p_b_scale_grid_
,
index_t
k_batch_
,
AElementwiseOperation
a_element_op_
,
BElementwiseOperation
b_element_op_
,
CElementwiseOperation
c_element_op_
,
bool
is_reduce_
=
false
)
:
Problem
{
M_
,
N_
,
K_
,
StrideA_
,
StrideB_
,
StrideC_
,
k_batch_
},
:
Problem
{
M_
,
N_
,
K_
,
StrideA_
,
StrideB_
,
StrideC_
,
StrideScaleB_
,
k_batch_
},
p_a_grid
{
p_a_grid_
},
p_b_grid
{
p_b_grid_
},
p_c_grid
{
p_c_grid_
},
...
...
@@ -670,15 +666,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
}
//
//
Calculate B scale offset
//
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
//
{
//
scale_k_split_offset = blockIdx.z * (karg.K
/ 64
) * karg.StrideB;
//
}
//
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
//
{
//
scale_k_split_offset = blockIdx.z * (karg.K
/ 64) * karg.N
;
//
}
// Calculate B scale offset
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
)
{
scale_k_split_offset
=
blockIdx
.
z
*
(
karg
.
KRead
/
ScaleBlockK
)
*
karg
.
StrideB
;
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
scale_k_split_offset
=
blockIdx
.
z
*
(
karg
.
KRead
/
ScaleBlockK
)
;
}
if
(
blockIdx
.
z
<
static_cast
<
uint32_t
>
(
karg
.
KBatch
-
1
))
{
...
...
@@ -701,7 +697,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
index_t
a_k_split_offset
;
index_t
b_k_split_offset
;
//
index_t scale_k_split_offset; // New member for scale matrix offset
index_t
scale_k_split_offset
;
// New member for scale matrix offset
index_t
c_reduce_offset
;
};
...
...
@@ -1273,6 +1269,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
template
<
typename
AGridDesc_AK0_M_K1
,
typename
BGridDesc_BK0_N_K1
,
typename
BScaleGridDesc_BN_AK
,
typename
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
...
...
@@ -1285,6 +1282,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
const
Problem
&
problem
,
const
AGridDesc_AK0_M_K1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_K1
&
b_grid_desc_bk0_n_bk1
,
const
BScaleGridDesc_BN_AK
&
b_scale_grid_desc_bn_ak
,
const
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
)
{
...
...
@@ -1295,12 +1293,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
// B Scale grid and buffer
const
auto
b_scale_grid_desc_bn_ak
=
make_naive_tensor_descriptor
(
make_tuple
(
math
::
integer_divide_ceil
(
problem
.
N
,
ScaleBlockN
),
math
::
integer_divide_ceil
(
problem
.
K
,
ScaleBlockK
)),
make_tuple
(
math
::
integer_divide_ceil
(
problem
.
K
,
ScaleBlockK
),
1
));
// B Scale buffer
const
auto
b_scale_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_scale_grid
,
b_scale_grid_desc_bn_ak
.
GetElementSpaceSize
());
...
...
@@ -1703,8 +1696,15 @@ struct GridwiseGemm_xdl_cshuffle_v3
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
// B Scale grid
const
auto
b_scale_grid_desc_bn_ak
=
make_naive_tensor_descriptor
(
make_tuple
(
math
::
integer_divide_ceil
(
problem
.
N
,
ScaleBlockN
),
math
::
integer_divide_ceil
(
problem
.
K
,
ScaleBlockK
)),
make_tuple
(
problem
.
StrideScaleB
,
1
));
Run
<
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_scale_grid_desc_bn_ak
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
...
...
@@ -1716,11 +1716,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
problem
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_scale_grid_desc_bn_ak
,
c_grid_desc_mblock_mperblock_nblock_nperblock
);
}
template
<
typename
AGridDesc_AK0_M_K1
,
typename
BGridDesc_BK0_N_K1
,
typename
BScaleGridDesc_BN_AK
,
typename
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
...
...
@@ -1734,6 +1736,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
const
Problem
&
problem
,
const
AGridDesc_AK0_M_K1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_K1
&
b_grid_desc_bk0_n_bk1
,
const
BScaleGridDesc_BN_AK
&
b_scale_grid_desc_bn_ak
,
const
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
)
{
...
...
@@ -1744,12 +1747,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
// B Scale grid and buffer
const
auto
b_scale_grid_desc_bn_ak
=
make_naive_tensor_descriptor
(
make_tuple
(
math
::
integer_divide_ceil
(
problem
.
N
,
ScaleBlockN
),
math
::
integer_divide_ceil
(
problem
.
K
,
ScaleBlockK
)),
make_tuple
(
math
::
integer_divide_ceil
(
problem
.
K
,
ScaleBlockK
),
1
));
// B Scale buffer
const
auto
b_scale_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_scale_grid
,
b_scale_grid_desc_bn_ak
.
GetElementSpaceSize
());
...
...
@@ -2164,8 +2162,14 @@ struct GridwiseGemm_xdl_cshuffle_v3
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
const
auto
b_scale_grid_desc_bn_ak
=
make_naive_tensor_descriptor
(
make_tuple
(
math
::
integer_divide_ceil
(
problem
.
N
,
ScaleBlockN
),
math
::
integer_divide_ceil
(
problem
.
K
,
ScaleBlockK
)),
make_tuple
(
problem
.
StrideScaleB
,
1
));
Run_2Lds
<
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_scale_grid_desc_bn_ak
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
...
...
@@ -2178,6 +2182,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
problem
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_scale_grid_desc_bn_ak
,
c_grid_desc_mblock_mperblock_nblock_nperblock
);
}
};
...
...
profiler/include/profiler/profile_gemm_b_scale_impl.hpp
View file @
8713ade3
...
...
@@ -301,6 +301,7 @@ bool profile_gemm_b_scale_impl(int do_verification,
StrideA
,
StrideB
,
StrideC
,
Scale_Stride_BN
,
static_cast
<
BScaleDataType
*>
(
b1_device_buf
.
GetDeviceBuffer
()),
kbatch_curr
,
a_element_op
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment