gaoqiong / composable_kernel_ROCM · Commits

Commit 8c0e03ba, authored Feb 11, 2025 by mtgu0705
General fix.
parent f0fba871

Showing 6 changed files with 38 additions and 12 deletions (+38 -12)
example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp  (+5, -4)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp  (+9, -0)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp  (+16, -3)
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp  (+4, -2)
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp  (+1, -0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp  (+3, -3)
example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp

@@ -22,7 +22,7 @@ using CLayout = Row;
 void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl)
 {
-    int KPack = 32;
+    int KPack = 32; // int4 -> 32, fp8 -> 16, fp16 -> 8
     int NLane = NXdl;
     int KLane = 64 / NLane;

@@ -174,7 +174,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
     DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());

-    int NperXdl = GetPreShuffleParameters;
+    // do GEMM
+    auto gemm = DeviceGemmV2Instance{};
+    int NperXdl = gemm.GetPreShuffleParameters();

     preShuffleBuffer(b_k_n.mData.data(), b_k_n_preshuffled.mData.data(), N, K, NperXdl);

     // weight permute

@@ -263,8 +266,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     auto b_element_op = BElementOp{};
     auto c_element_op = CElementOp{};

-    // do GEMM
-    auto gemm = DeviceGemmV2Instance{};
     auto invoker = gemm.MakeInvoker();

     float ave_time = 0;
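The net effect of the second and third hunks is that run_gemm now constructs the DeviceGemmV2Instance before the weight pre-shuffle and queries NperXdl from it via gemm.GetPreShuffleParameters(), instead of naming GetPreShuffleParameters without an object as before. Below is a minimal, self-contained sketch (not taken from the CK sources) of the sizing arithmetic that preShuffleBuffer is built on; the NXdl value of 16 and the reading of 64 as the wavefront width are illustrative assumptions, not facts from this diff.

// Sketch only: the per-wave tiling arithmetic used by preShuffleBuffer, with an assumed NXdl.
#include <cstdio>

int main()
{
    const int NXdl  = 16;         // assumption: in the example this comes from gemm.GetPreShuffleParameters()
    const int KPack = 32;         // int4 -> 32, fp8 -> 16, fp16 -> 8 (comment added in this commit)
    const int NLane = NXdl;       // one N lane per XDL output column
    const int KLane = 64 / NLane; // remaining lanes of the (assumed) 64-wide wave stack along K
    std::printf("NLane=%d KLane=%d KPack=%d\n", NLane, KLane, KPack);
    return 0;
}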
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp

@@ -11,6 +11,15 @@
 namespace ck {

+enum struct BlockGemmPipelineVersion
+{
+    v1, // Naive
+    v2, // Mem
+    v3, // Comp
+    v4, // Comp, double lds buffer
+    v5, // Comp, double global prefetch register buffer
+};
+
 template <BlockGemmPipelineVersion BlkGemmPipelineVer,
           BlockGemmPipelineScheduler BlkGemmPipeSche,
           index_t BlockSize,
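The selector header now carries the BlockGemmPipelineVersion enumeration itself, one enumerator per pipeline flavour (naive, memory-bound, compute-bound, and the two double-buffered compute variants). As a purely hypothetical illustration of compile-time dispatch on such an enum -- this is not the CK selector's actual implementation -- a small if constexpr switch could look like this:

// Hypothetical sketch, not the CK selector: compile-time dispatch on BlockGemmPipelineVersion.
#include <cstdio>

enum struct BlockGemmPipelineVersion
{
    v1, // Naive
    v2, // Mem
    v3, // Comp
    v4, // Comp, double lds buffer
    v5, // Comp, double global prefetch register buffer
};

template <BlockGemmPipelineVersion Ver>
constexpr const char* pipeline_name()
{
    if constexpr(Ver == BlockGemmPipelineVersion::v1) { return "naive"; }
    else if constexpr(Ver == BlockGemmPipelineVersion::v2) { return "memory-bound"; }
    else if constexpr(Ver == BlockGemmPipelineVersion::v3) { return "compute-bound"; }
    else if constexpr(Ver == BlockGemmPipelineVersion::v4) { return "compute-bound, double LDS buffer"; }
    else { return "compute-bound, double global prefetch register buffer"; }
}

int main()
{
    std::printf("selected pipeline: %s\n", pipeline_name<BlockGemmPipelineVersion::v3>());
    return 0;
}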
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp

@@ -46,7 +46,8 @@ struct BlockwiseGemmXdlops_pipeline_base
     static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
     static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
     static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
-    static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
+    // static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
+    static constexpr index_t B_K1 = BBlockTransferSrcScalarPerVector;

     static constexpr auto xdlops_gemm =
         XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, ComputeDataType, TransposeC>{};

@@ -54,8 +55,9 @@ struct BlockwiseGemmXdlops_pipeline_base
     static constexpr index_t AMmaKStride = KPack;
     static constexpr index_t BMmaKStride = KPack;

     static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
     static constexpr index_t KRepeat    = KPerThread / KPack;
+    static constexpr index_t KPerInnerLoop = KPack;

     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

@@ -111,6 +113,17 @@ struct BlockwiseGemmXdlops_pipeline_base
         return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
     }

+    __device__ static auto CalculateAThreadOriginDataIndex6D()
+    {
+        const auto wave_idx = GetWaveIdx();
+
+        const auto waveId_m = wave_idx[I0];
+
+        const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
+
+        return make_tuple(0, waveId_m, xdlops_a_idx[I1], 0, xdlops_a_idx[I0], 0);
+    }
+
     __device__ static auto CalculateBThreadOriginDataIndex()
     {
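Two things change in the pipeline base: B_K1 is now taken from BBlockTransferSrcScalarPerVector rather than from the B tile descriptor (whose original definition is kept as a comment), and a KPerInnerLoop constant joins KPerThread and KRepeat, alongside a new 6D variant of the A-thread origin calculation. A small worked example of how these constants relate, using assumed numbers that are not part of this diff, is:

// Worked example with assumed parameters (illustration only, not from this diff).
#include <cstdio>

int main()
{
    const int KPerBlock   = 128; // assumed block tile along K
    const int KPack       = 32;  // assumed packed-K granularity (e.g. packed int4)
    const int K0PerXdlops = 1;   // assumed xdlops_gemm.K0PerXdlops for the chosen MFMA

    const int KPerThread    = KPerBlock / K0PerXdlops; // 128
    const int KRepeat       = KPerThread / KPack;      // 4 inner-loop repeats along K
    const int KPerInnerLoop = KPack;                   // the constant added in this commit

    std::printf("KPerThread=%d KRepeat=%d KPerInnerLoop=%d\n",
                KPerThread, KRepeat, KPerInnerLoop);
    return 0;
}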
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp

@@ -142,8 +142,10 @@ struct DeviceGemmV2BPreshuffle : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;

-    virtual bool GetPermuteB()         = 0;
-    virtual ck::index_t GetKPerBlock() = 0;
+    virtual bool GetPermuteA()            = 0;
+    virtual bool GetPermuteB()            = 0;
+    virtual ck::index_t GetKPerBlock()    = 0;
+    virtual int GetPreShuffleParameters() = 0;
 };

 } // namespace device
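The DeviceGemmV2BPreshuffle base interface grows two pure-virtual query hooks, GetPermuteA() and GetPreShuffleParameters(), next to the existing GetPermuteB() and GetKPerBlock(); the latter new hook is what the example above uses to fetch NperXdl from the concrete instance. The toy sketch below shows only the override-through-a-base-pointer pattern; the class names, template-free signatures, and returned numbers are invented for illustration and are not the real CK types.

// Toy sketch of the interface pattern (hypothetical classes, not the CK templates).
#include <cstdio>
#include <memory>

struct DeviceGemmBase
{
    virtual ~DeviceGemmBase() = default;
    virtual bool GetPermuteA()            = 0;
    virtual bool GetPermuteB()            = 0;
    virtual int GetKPerBlock()            = 0;
    virtual int GetPreShuffleParameters() = 0;
};

struct MyPreshuffleGemm final : DeviceGemmBase
{
    bool GetPermuteA() override { return false; }          // assumed
    bool GetPermuteB() override { return true; }           // assumed
    int GetKPerBlock() override { return 128; }            // assumed K tile
    int GetPreShuffleParameters() override { return 16; }  // assumed NPerXdl
};

int main()
{
    std::unique_ptr<DeviceGemmBase> op = std::make_unique<MyPreshuffleGemm>();
    std::printf("NperXdl=%d\n", op->GetPreShuffleParameters());
    return 0;
}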
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp

@@ -328,6 +328,7 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle<AL
         else
         {
             throw std::runtime_error("Only support pipeline ver v1, v2, v3 now!");
         }
+        }
 #if 0
         else
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp

@@ -1134,7 +1134,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
         p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

     const AElementwiseOperation a_element_op{};
-    const BElementwiseOperation b_element_op{};
+    // const BElementwiseOperation b_element_op{};
     const CElementwiseOperation c_element_op{};

     // divide block work by [M, N]

@@ -1514,7 +1514,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
         p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

     const AElementwiseOperation a_element_op{};
-    const BElementwiseOperation b_element_op{};
+    // const BElementwiseOperation b_element_op{};
     const CElementwiseOperation c_element_op{};

     // divide block work by [M, N]

@@ -1614,7 +1614,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
     auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);

     constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
-    constexpr auto b_block_slice_copy_step = make_multi_index(, 0, KRepeat, 0);
+    constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);

     // Blockwise GEMM pipeline
     static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);