gaoqiong / composable_kernel_ROCM / Commits / be79b63b

Commit be79b63b, authored Feb 15, 2025 by mtgu0705

    fix bug in moe_gemm1.cpp, now function pass.
parent 1b0b7810

Showing 3 changed files with 221 additions and 29 deletions
  example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp               +190  -26
  include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp      +28   -0
  include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp     +3   -3
example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp  (view file @ be79b63b)
...
...
@@ -89,17 +89,70 @@ struct MulABScaleSilu
    }
};

// using DsLayout   = DsLayoutGate;
// using DsDataType = DsDataTypeGate;
using CDEElementOp = MulABScale;
// using CDEElementOp = MulABScaleSiluMulGate;

#if 1
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
{
    int KPack = 32;
    int NLane = NXdl;
    int KLane = 64 / NLane;
    int K0    = K / (KLane * KPack);
    // K -> K0 KLane KPack
    // N -> N0 NLane
    // N, K -> N0 K0 KLane NLane KPack
    int tempk;
    for(int n = 0; n < N; ++n)
    {
        for(int k = 0; k < K; ++k)
        {
            int n0 = n / NLane;
            int n1 = n % NLane;

            int k0 = k / (KLane * KPack);
            tempk  = k % (KLane * KPack);
            int k1 = tempk / KPack;
            int k2 = tempk % KPack;

            int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
                              k1 * KPack * NLane + n1 * KPack + k2;

            dst[outputIndex / 2] = src[(n * K + k) / 2];
        }
    }
}
#endif
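A quick way to convince yourself that the N, K -> N0 K0 KLane NLane KPack re-layout above is a permutation: the following standalone sketch (assumed toy sizes N = 32, K = 128, NXdl = 16; not part of this diff) applies the same index decomposition and checks that every destination index is produced exactly once.

#include <cstdio>
#include <vector>

int main()
{
    const int N = 32, K = 128, NXdl = 16; // assumed toy sizes, must satisfy the divisibility used above
    const int KPack = 32;                 // same constants as preShuffleBuffer
    const int NLane = NXdl;
    const int KLane = 64 / NLane;
    const int K0    = K / (KLane * KPack);

    std::vector<int> hits(N * K, 0);
    for(int n = 0; n < N; ++n)
        for(int k = 0; k < K; ++k)
        {
            int n0    = n / NLane, n1 = n % NLane;
            int k0    = k / (KLane * KPack);
            int tempk = k % (KLane * KPack);
            int k1    = tempk / KPack, k2 = tempk % KPack;
            int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
                              k1 * KPack * NLane + n1 * KPack + k2;
            ++hits[outputIndex];
        }

    int collisions = 0;
    for(int h : hits)
        collisions += (h != 1);
    std::printf("every destination index hit exactly once: %s\n", collisions == 0 ? "yes" : "no");
    return 0;
}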
float i4_to_f32_gfx9(uint8_t i4)
{
    static std::unordered_map<uint8_t, float> u = {{0b1000, -0.5000f},
                                                   {0b1001, -0.4375f},
                                                   {0b1010, -0.3750f},
                                                   {0b1011, -0.3125f},
                                                   {0b1100, -0.2500f},
                                                   {0b1101, -0.1875f},
                                                   {0b1110, -0.1250f},
                                                   {0b1111, -0.0625f},
                                                   {0b0,    +0.0000f},
                                                   {0b1,    +0.0625f},
                                                   {0b10,   +0.1250f},
                                                   {0b11,   +0.1875f},
                                                   {0b100,  +0.2500f},
                                                   {0b101,  +0.3125f},
                                                   {0b110,  +0.3750f},
                                                   {0b111,  +0.4375f}};
    return u[i4];
}
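The lookup table above is equivalent to reading the nibble as a signed two's-complement 4-bit value scaled by 1/16 (0.0625); a small arithmetic sketch of that equivalence (illustration only, not part of this diff):

#include <cstdint>

// Decode a 4-bit code the same way as the table: sign-extend the nibble, then scale by 1/16.
float i4_to_f32_arith(std::uint8_t i4)
{
    int v = i4 & 0xF;
    if(v & 0x8)
        v -= 16; // two's-complement sign extension
    return static_cast<float>(v) * 0.0625f;
}
// e.g. 0b1000 -> -8 * 0.0625 = -0.5000f, 0b0111 -> 7 * 0.0625 = +0.4375f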
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp  = PassThrough;
using BElementOp  = PassThrough;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;

static constexpr ck::index_t MPerBlock = 128;
#if 0
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
...
...
@@ -115,7 +168,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
Row, Col, DsLayout, ELayout,
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
-    256, MPerBlock, 128, KPerBlock,
+    64, MPerBlock, 16, KPerBlock,
AK1, BK1,
MNPerXDL, MNPerXDL,
MXDLPerWave, 1,
...
...
@@ -124,6 +177,23 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
CShuffleMXDLPerWave, 1, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>;
// clang-format on
#else
static constexpr ck::index_t MPerBlock = 16;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm<
    Row, Col, DsLayout, ELayout,
    A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
    AElementOp, BElementOp, CDEElementOp, GemmSpec,
    64, 16, 16, 128,
    16, 32,
    16, 16,
    1, 1,
    S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
    2, 16, 16, 0,
    S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
    2, 32, 32, 0,
    1, 1, S<1, 16, 1, 4>, S<4, 1, 1>,
    ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, true, A0DataType>;
// clang-format on
#endif
int main(int argc, char* argv[])
{
...
...
@@ -138,11 +208,12 @@ int main(int argc, char* argv[])
    // GEMM shape
    ck::index_t N = 6144;
    ck::index_t K = 8192;
-   ck::index_t experts         = 8;
-   ck::index_t sorted_tile_num = 8;
+   ck::index_t experts         = 1;
+   ck::index_t sorted_tile_num = 1;
    ck::index_t sorted_tile_size = MPerBlock;
    ck::index_t SORTED_SIZE      = sorted_tile_num * sorted_tile_size;
-   ck::index_t tokens = 128;
+   // ck::index_t tokens = 128;
+   ck::index_t tokens = 16;

    if(argc == 1)
    {
...
...
@@ -169,7 +240,6 @@ int main(int argc, char* argv[])
    ck::index_t StrideA = K;
    ck::index_t StrideB = K;
    ck::index_t StrideE = N;
    ck::index_t batch_stride_B = K * N;

    constexpr ck::index_t NumDTensor = DsDataType::Size();
    constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0};
...
...
@@ -194,8 +264,8 @@ int main(int argc, char* argv[])
    expert_ids.savetxt("expert_ids.txt", "int");
    sorted_token_ids.savetxt("sorted_token_ids.txt", "int");

    Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
-   Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N * K, K, 1}));
-   Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N * K, K, 1}));
+   Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
+   Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
    Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
    Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
    Tensor<EDataType> e_m_n_host_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
...
...
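Note on the descriptor change above: both forms describe the same row-major N x K block per expert; only the logical index order changes. With the old dims {experts, N, K} and strides {N*K, K, 1}, element (e, n, k) sits at offset e*N*K + n*K + k; with the new dims {experts, K, N} and strides {N*K, 1, K}, the same element is reached as (e, k, n) at e*N*K + k + n*K, which is the same flat offset. A tiny sketch with assumed sizes (not part of this diff):

#include <cassert>

int main()
{
    const int experts = 2, N = 4, K = 8; // assumed toy sizes
    for(int e = 0; e < experts; ++e)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
            {
                const int off_old = e * N * K + n * K + k; // dims {experts, N, K}, strides {N*K, K, 1}
                const int off_new = e * N * K + k + n * K; // dims {experts, K, N}, strides {N*K, 1, K}
                assert(off_old == off_new);                // identical memory layout
            }
    return 0;
}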
@@ -217,10 +287,22 @@ int main(int argc, char* argv[])
        d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 3});
        break;
    case 2:
-       a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
-       b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
-       d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
-       d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
+       a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
+       b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
+       d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
+       d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
        break;
+   case 3:
+       a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
+       b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
+       d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
+       d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
+       break;
+   case 4:
+       a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
+       b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
+       d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
+       d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
+       break;
    default:
        a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
...
...
@@ -238,6 +320,7 @@ int main(int argc, char* argv[])
    DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());

    a0_t_k.savetxt("a.txt");
    sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
    expert_ids_dev.ToDevice(expert_ids.mData.data());
    a0_device_buf.ToDevice(a0_t_k.mData.data());
...
...
@@ -252,8 +335,9 @@ int main(int argc, char* argv[])
    // do GEMM
    auto device_op = DeviceOpInstance{};
    // preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl);
    printf("Start PreShuffle\n");
#if 1
    preShuffleBuffer(b0_e_n_k.mData.data(),
                     b0_preshuffled.mData.data(),
                     N * experts,
                     K,
                     device_op.GetPreShuffleParameters());
#else
    // weight pre-shuffle
    int KPack = 32; // int4 -> 32, fp8 -> 16, fp16 -> 8
    int NLane = device_op.GetPreShuffleParameters();
...
...
@@ -279,20 +363,20 @@ int main(int argc, char* argv[])
                int k2    = tempk % KPack;

                int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
                                  k1 * KPack * NLane + n1 * KPack + k2;
-               b0_preshuffled(e * batch_stride_B + outputIndex) =
-                   b0_e_n_k(e * batch_stride_B + n * K + k);
+               b0_preshuffled(e, outputIndex % K, outputIndex / K) = b0_e_n_k(e, k, n);
            }
        }
    }
    printf("End PreShuffle, and Start vector permute\n");
#endif

    // vector pk_i4x4 permute
    for(int e = 0; e < experts; e++)
    {
        for(int i = 0; i < N; i++)
        {
-           for(int j = 0; j < K; j++)
+           for(int j = 0; j < K; j += 8)
            {
                int input[8];
...
...
@@ -341,7 +425,6 @@ int main(int argc, char* argv[])
    b0_device_buf.ToDevice(b0_preshuffled.mData.data());
    printf("End Permute and Start GEMM\n");

    auto invoker  = device_op.MakeInvoker();
    auto argument = device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
...
...
@@ -370,6 +453,7 @@ int main(int argc, char* argv[])
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
if
(
time_kernel
)
{
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
...
...
@@ -381,8 +465,8 @@ int main(int argc, char* argv[])
        float gb_per_sec = num_btype / 1.E6 / ave_time;

-       std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
-                 << " GB/s" << std::endl;
+       std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
+                 << " GB/s" << device_op.GetTypeString() << std::endl;
    }

    if(do_verification)
...
...
@@ -421,11 +505,91 @@ int main(int argc, char* argv[])
        e_device_buf.FromDevice(e_m_n_device_result.mData.data());
        e_m_n_device_result.savetxt("out.txt");
        e_m_n_host_result.savetxt("ref.txt");
#if 0
printf("A Matrix:\n");
for(int t = 0; t < tokens; t++)
{
for(int k = 0; k < K; k++)
{
printf("%f,", ck::type_convert<float>(a0_t_k(t, k)));
}
printf("\n");
}
printf("\n");
printf("B Matrix:\n");
for(int e = 0; e < experts; e++)
{
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
ck::pk_i4_t i4x2 = b0_e_n_k(e, k, n).data;
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
printf("%f,", i4_to_f32_gfx9(i4));
}
printf("\n");
}
printf("\n");
}
printf("\n");
printf("B preshuflled Matrix:\n");
for(int e = 0; e < experts; e++)
{
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
ck::pk_i4_t i4x2 = b0_preshuffled(e, k, n).data;
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
printf("%f,", i4_to_f32_gfx9(i4));
}
printf("\n");
}
printf("\n");
}
printf("\n");
printf("C device Matrix:\n");
for(int m = 0; m < SORTED_SIZE; m++)
{
for(int n = 0; n < N; n++)
{
printf("%f,", ck::type_convert<float>(e_m_n_device_result(m, n)));
}
printf("\n");
}
printf("\n");
printf("C host Matrix:\n");
for(int m = 0; m < SORTED_SIZE; m++)
{
for(int n = 0; n < N; n++)
{
printf("%f,", ck::type_convert<float>(e_m_n_host_result(m, n)));
}
printf("\n");
}
#endif
        return ck::utils::check_err(
                   e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
                   ? 0
                   : 1;
}
    printf("end of kernel\n");

    return 0;
}
include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp  (view file @ be79b63b)
...
...
@@ -362,6 +362,34 @@ struct DeviceMoeGemm
                    throw std::runtime_error("todo: only v1 & v2 support now");
                }
            }
#if 1
            else
            {
                // Tail number always 1
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    // if(arg.KBatch > 1)
                    // {
                    //     const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle<
                    //         GridwiseGemm,
                    //         false,
                    //         InMemoryDataOperationEnum::AtomicAdd,
                    //         minimum_occupancy,
                    //         TailNumber::Odd>;
                    //     Run(kernel);
                    // }
                    // else
                    {
                        const auto kernel = kernel_moe_gemm_gather<GridwiseGemm,
                                                                   true,
                                                                   InMemoryDataOperationEnum::Set,
                                                                   minimum_occupancy,
                                                                   TailNumber::Odd>;
                        RunKernel(kernel);
                    }
                }
            }
#endif

            return ave_time;
        }
...
...
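The block added in this hunk selects the gather kernel at compile time from the pipeline-version template parameter. A reduced sketch of that if constexpr dispatch pattern (hypothetical enum and function names, not the CK API):

#include <cstdio>

enum class PipelineVersion { v1, v2 };

template <PipelineVersion Ver>
void launch_gemm()
{
    // Resolved at compile time: only the matching branch is instantiated.
    if constexpr(Ver == PipelineVersion::v1)
        std::printf("launching v1 (odd tail) kernel\n");
    else
        std::printf("launching v2 kernel\n");
}

int main()
{
    launch_gemm<PipelineVersion::v1>();
    launch_gemm<PipelineVersion::v2>();
}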
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp  (view file @ be79b63b)
...
...
@@ -1086,7 +1086,7 @@ struct GridwiseMoeGemmGather
        }

        // check gridwise gemm pipeline
-#if 1
+#if 0
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
        if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
...
...
@@ -1193,7 +1193,7 @@ struct GridwiseMoeGemmGather
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-           p_b_grid + expert_id * expert_stride,
-           b_grid_desc_bpreshuffled.GetElementSpaceSize());
+           p_b_grid + expert_id * expert_stride / BPackedSize,
+           b_grid_desc_bpreshuffled.GetElementSpaceSize());
        // if(threadIdx.x==0)
        // printf("tid %d eid %d expert_stride %d bufsize %d\n",
        // threadIdx.x, expert_id, expert_stride, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
...
...
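On the `/ BPackedSize` added above: this example's weights are packed int4 (ck::pk_i4_t stores two 4-bit values per byte), so expert_stride counted in elements appears to overshoot the packed buffer by the packing factor; dividing by BPackedSize converts the per-expert element offset into units of the packed storage that p_b_grid points to. A minimal sketch of the same arithmetic (assumed sizes and names, not CK code):

#include <cstdint>
#include <cstdio>

int main()
{
    // Assumed illustration: N x K int4 weights per expert, two values packed per byte.
    const int N = 16, K = 32, BPackedSize = 2;
    const long expert_stride = static_cast<long>(N) * K; // stride in elements
    std::uint8_t storage[3 * N * K / 2] = {};            // 3 experts, packed storage

    for(int expert_id = 0; expert_id < 3; ++expert_id)
    {
        // Convert the element offset to packed (byte) units before the pointer arithmetic.
        std::uint8_t* p_b_expert = storage + expert_id * expert_stride / BPackedSize;
        std::printf("expert %d starts at byte %ld\n",
                    expert_id,
                    static_cast<long>(p_b_expert - storage));
    }
    return 0;
}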
@@ -1248,7 +1248,7 @@ struct GridwiseMoeGemmGather
        decltype(b_grid_desc_bpreshuffled),
        decltype(b_block_desc_bk0_n_bk1),
        Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
        Sequence<0, 1, 2, 3>,
        Sequence<1, 2, 0, 3>,
        3,
        BBlockTransferSrcScalarPerVector,
        BThreadTransferSrcResetCoordinateAfterRun,
...
...