Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8a5bb9f3
Commit
8a5bb9f3
authored
Feb 08, 2025
by
coderfeli
Browse files
add files , build and run ok
parent
bd64a30b
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
598 additions
and
30 deletions
+598
-30
example/65_gemm_multiply_multiply/CMakeLists.txt
example/65_gemm_multiply_multiply/CMakeLists.txt
+2
-1
example/65_gemm_multiply_multiply/moe_gemm1.cpp
example/65_gemm_multiply_multiply/moe_gemm1.cpp
+0
-0
example/65_gemm_multiply_multiply/moe_gemm2.cpp
example/65_gemm_multiply_multiply/moe_gemm2.cpp
+399
-0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp
+2
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
...id/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
+22
-28
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp
...ry/reference_tensor_operation/cpu/reference_moe_gemm2.hpp
+173
-0
No files found.
example/65_gemm_multiply_multiply/CMakeLists.txt
View file @
8a5bb9f3
...
...
@@ -5,4 +5,5 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_m
# target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable
(
example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp
)
add_example_executable
(
example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp
)
add_example_executable
(
example_moe_gemm_fp16 moe_gemm_fp16.cpp
)
add_example_executable
(
example_moe_gemm1 moe_gemm1.cpp
)
add_example_executable
(
example_moe_gemm2 moe_gemm2.cpp
)
example/65_gemm_multiply_multiply/moe_gemm
_fp16
.cpp
→
example/65_gemm_multiply_multiply/moe_gemm
1
.cpp
View file @
8a5bb9f3
File moved
example/65_gemm_multiply_multiply/moe_gemm2.cpp
0 → 100644
View file @
8a5bb9f3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
// using BF16 = ck::bhalf_t;
using
F8
=
ck
::
f8_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
A0DataType
=
F16
;
using
B0DataType
=
F16
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
D0DataType
=
F32
;
using
D1DataType
=
F32
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
using
EDataType
=
F16
;
using
A0Layout
=
Row
;
using
B0Layout
=
Col
;
using
D0Layout
=
Row
;
using
D1Layout
=
Col
;
using
DsLayout
=
ck
::
Tuple
<
D0Layout
,
D1Layout
>
;
using
ELayout
=
Row
;
struct
MultiplyMultiply
{
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
>
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
,
const
D1
&
d1
)
const
;
template
<
>
__host__
__device__
constexpr
void
operator
()
<
EDataType
,
float
,
float
,
float
>
(
EDataType
&
e
,
const
float
&
c
,
const
float
&
d0
,
const
float
&
d1
)
const
{
// const float x0_f = c * d0 * d1;
const
float
x0_f
=
c
;
// printf("epi %f\n", c);
e
=
ck
::
type_convert
<
EDataType
>
(
x0_f
);
}
// template <>
// __host__ __device__ constexpr void operator()<BF16, float, float, float>(BF16& e,
// const float& c,
// const float& d0,
// const float& d1) const
// {
// const float x0_f = c;
// // const float x0_f = c * d0 * d1;
// e = ck::type_convert<BF16>(x0_f);
// }
};
void
preShuffleBuffer
(
const
B0DataType
*
src
,
B0DataType
*
dst
,
int
N
,
int
K
,
int
NXdl
)
{
int
KPack
=
16
/
sizeof
(
B0DataType
);
int
NLane
=
NXdl
;
int
KLane
=
64
/
NLane
;
int
K0
=
K
/
(
KLane
*
KPack
);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int
tempk
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
int
n0
=
n
/
NLane
;
int
n1
=
n
%
NLane
;
int
k0
=
k
/
(
KLane
*
KPack
);
tempk
=
k
%
(
KLane
*
KPack
);
int
k1
=
tempk
/
KPack
;
int
k2
=
tempk
%
KPack
;
int
outputIndex
=
n0
*
KPack
*
NLane
*
KLane
*
K0
+
k0
*
KPack
*
NLane
*
KLane
+
k1
*
KPack
*
NLane
+
n1
*
KPack
+
k2
;
dst
[
outputIndex
]
=
src
[
n
*
K
+
k
];
}
}
}
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
MultiplyMultiply
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
ck
::
index_t
MPerBlock
=
32
;
static
constexpr
ck
::
index_t
KPerBlock
=
256
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
MXDLPerWave
=
MPerBlock
/
32
;
//todo fix this constraint
static
constexpr
ck
::
index_t
AK1
=
16
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
BK1
=
16
/
sizeof
(
B0DataType
);
static
constexpr
ck
::
index_t
EVec
=
16
/
sizeof
(
EDataType
);
// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>|
///###### RCR
// kernel 1: 256->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, EDataType>;
<
Row
,
Col
,
DsLayout
,
ELayout
,
A0DataType
,
B0DataType
,
DsDataType
,
EDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
//threadnum, mblock, nblock, kblock
256
,
MPerBlock
,
128
,
KPerBlock
,
// ak1, bk1
AK1
,
BK1
,
// mn_perxdl
32
,
32
,
// mn_xdlperwave
MXDLPerWave
,
1
,
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
AK1
,
AK1
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
AK1
,
AK1
,
0
,
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
EVec
,
EVec
,
1
>
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v1
,
A0DataType
>
;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
// clang-format on
int
main
(
int
argc
,
char
*
argv
[])
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
true
;
// tokens = 1
// topk = 1
// experts = 8
// per expert:
// GEMM shape
ck
::
index_t
N
=
128
;
ck
::
index_t
K
=
1024
;
ck
::
index_t
experts
=
1
;
ck
::
index_t
sorted_tile_num
=
1
;
ck
::
index_t
sorted_tile_size
=
MPerBlock
;
ck
::
index_t
SORTED_SIZE
=
sorted_tile_num
*
sorted_tile_size
;
ck
::
index_t
tokens
=
1
;
if
(
argc
==
1
)
{
// use default case
}
else
if
(
argc
==
6
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
N
=
std
::
stoi
(
argv
[
4
]);
K
=
std
::
stoi
(
argv
[
5
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 5: N, K
\n
"
);
exit
(
0
);
}
ck
::
index_t
StrideA
=
K
;
ck
::
index_t
StrideB
=
K
;
ck
::
index_t
StrideD
=
0
;
ck
::
index_t
StrideE
=
N
;
ck
::
index_t
KBatch
=
1
;
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
using
namespace
ck
::
literals
;
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
({
row
,
col
},
{
stride
,
1
_uz
});
}
else
{
return
HostTensorDescriptor
({
row
,
col
},
{
1
_uz
,
stride
});
}
};
// const ck::index_t experts = 8;
Tensor
<
ck
::
index_t
>
expert_ids
(
HostTensorDescriptor
({
experts
},
{
1
}));
Tensor
<
ck
::
index_t
>
sorted_token_ids
(
HostTensorDescriptor
({
SORTED_SIZE
},
{
1
}));
for
(
int
i
=
0
;
i
<
sorted_tile_num
;
i
++
)
{
expert_ids
.
mData
[
i
]
=
i
;
}
int
token_per_tile
=
tokens
/
sorted_tile_num
;
int
tokenid
=
0
;
// sorted_token_ids.mData[0] = 0;
for
(
int
i
=
0
;
i
<
SORTED_SIZE
;
i
++
)
{
int
tile_off
=
i
%
sorted_tile_size
;
if
(
tile_off
<
token_per_tile
)
sorted_token_ids
.
mData
[
i
]
=
tokenid
++
;
else
sorted_token_ids
.
mData
[
i
]
=
tokens
;
}
expert_ids
.
savetxt
(
"expert_ids.txt"
,
"int"
);
sorted_token_ids
.
savetxt
(
"sorted_token_ids.txt"
,
"int"
);
Tensor
<
A0DataType
>
a0_m_k
(
HostTensorDescriptor
({
SORTED_SIZE
,
K
},
{
K
,
1
}));
Tensor
<
B0DataType
>
b0_e_n_k
(
HostTensorDescriptor
({
experts
,
N
,
K
},
{
N
*
K
,
K
,
1
}));
Tensor
<
B0DataType
>
b0_preshuffled
(
HostTensorDescriptor
({
experts
,
N
,
K
},
{
N
*
K
,
K
,
1
}));
// Tensor<B0DataType> b0_e_n_k(f_host_tensor_descriptor(K, N * experts, StrideB, B0Layout{}));
// Tensor<B0DataType> b0_preshuffled(
// f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size
Tensor
<
D0DataType
>
d0_t_n
(
f_host_tensor_descriptor
(
tokens
,
N
,
StrideD
,
D0Layout
{}));
Tensor
<
D1DataType
>
d1_t_n
(
f_host_tensor_descriptor
(
tokens
,
N
,
StrideD
,
D1Layout
{}));
Tensor
<
EDataType
>
e_t_n_host_result
(
HostTensorDescriptor
({
tokens
,
N
},
{
N
,
1
}));
Tensor
<
EDataType
>
e_t_n_device_result
(
HostTensorDescriptor
({
tokens
,
N
},
{
N
,
1
}));
std
::
cout
<<
"a0_m_k: "
<<
a0_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b0_e_n_k: "
<<
b0_e_n_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d1_t_n: "
<<
d1_t_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d0_t_n: "
<<
d0_t_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_t_n: "
<<
e_t_n_host_result
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
case
0
:
break
;
case
1
:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
A0DataType
>
{
-
2
,
2
});
b0_e_n_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
0
,
2
});
d0_t_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
-
2
,
2
});
d1_t_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
-
2
,
2
});
break
;
case
2
:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
A0DataType
>
{});
b0_e_n_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
d0_t_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
D0DataType
>
{});
d1_t_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
D1DataType
>
{});
break
;
default:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
0.0
,
1.0
});
b0_e_n_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
-
0.5
,
0.5
});
d0_t_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D0DataType
>
{
0.0
,
1.0
});
d1_t_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D1DataType
>
{
0.0
,
1.0
});
}
DeviceMem
sorted_token_ids_dev
(
sizeof
(
ck
::
index_t
)
*
sorted_token_ids
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
expert_ids_dev
(
sizeof
(
ck
::
index_t
)
*
expert_ids
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a0_device_buf
(
sizeof
(
A0DataType
)
*
a0_m_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b0_device_buf
(
sizeof
(
B0DataType
)
*
b0_e_n_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_t_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_t_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_t_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a0_m_k
.
savetxt
(
"a.txt"
);
sorted_token_ids_dev
.
ToDevice
(
sorted_token_ids
.
mData
.
data
());
expert_ids_dev
.
ToDevice
(
expert_ids
.
mData
.
data
());
a0_device_buf
.
ToDevice
(
a0_m_k
.
mData
.
data
());
d0_device_buf
.
ToDevice
(
d0_t_n
.
mData
.
data
());
d1_device_buf
.
ToDevice
(
d1_t_n
.
mData
.
data
());
e_device_buf
.
ToDevice
(
e_t_n_device_result
.
mData
.
data
());
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{};
constexpr
ck
::
index_t
NumDTensor
=
DsDataType
::
Size
();
constexpr
auto
I0
=
ck
::
Number
<
0
>
{};
// do GEMM
auto
device_op
=
DeviceOpInstance
{};
int
NPerXdl
=
device_op
.
GetPreShuffleParameters
();
preShuffleBuffer
(
b0_e_n_k
.
mData
.
data
(),
b0_preshuffled
.
mData
.
data
(),
N
*
experts
,
K
,
NPerXdl
);
b0_device_buf
.
ToDevice
(
b0_preshuffled
.
mData
.
data
());
auto
invoker
=
device_op
.
MakeInvoker
();
auto
argument
=
device_op
.
MakeArgument
(
sorted_token_ids_dev
.
GetDeviceBuffer
(),
expert_ids_dev
.
GetDeviceBuffer
(),
a0_device_buf
.
GetDeviceBuffer
(),
b0_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
NumDTensor
>
{
d0_device_buf
.
GetDeviceBuffer
(),
d1_device_buf
.
GetDeviceBuffer
()},
e_device_buf
.
GetDeviceBuffer
(),
tokens
,
SORTED_SIZE
,
N
,
K
,
StrideA
,
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
{
I0
,
I0
},
StrideE
,
KBatch
,
a_element_op
,
b_element_op
,
cde_element_op
);
if
(
!
device_op
.
IsSupportedArgument
(
argument
))
{
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
if
(
time_kernel
)
{
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
SORTED_SIZE
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
A0DataType
)
*
SORTED_SIZE
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
*
experts
+
sizeof
(
EDataType
)
*
SORTED_SIZE
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
}
if
(
do_verification
)
{
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
,
0
,
0
,
1
});
e_device_buf
.
FromDevice
(
e_t_n_device_result
.
mData
.
data
());
Tensor
<
CShuffleDataType
>
c_t_n
({
tokens
,
N
});
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceMoeGemm2
<
A0DataType
,
B0DataType
,
CShuffleDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
auto
ref_moe_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_moe_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_moe_gemm
.
MakeArgument
(
sorted_token_ids
,
expert_ids
,
sorted_tile_size
,
a0_m_k
,
b0_e_n_k
,
c_t_n
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
for
(
int
m
=
0
;
m
<
SORTED_SIZE
;
++
m
)
{
const
int
t
=
sorted_token_ids
(
m
);
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
cde_element_op
(
e_t_n_host_result
(
t
,
n
),
c_t_n
(
t
,
n
),
d0_t_n
(
t
,
n
),
d1_t_n
(
t
,
n
));
}
}
e_device_buf
.
FromDevice
(
e_t_n_device_result
.
mData
.
data
());
e_t_n_device_result
.
savetxt
(
"out.txt"
);
e_t_n_host_result
.
savetxt
(
"ref.txt"
);
return
ck
::
utils
::
check_err
(
e_t_n_device_result
,
e_t_n_host_result
,
"Error: Incorrect results!"
,
1e-3
,
5e-2
)
?
0
:
1
;
}
return
0
;
}
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp
View file @
8a5bb9f3
...
...
@@ -48,6 +48,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
static
constexpr
index_t
nDim
=
remove_cvref_t
<
tuple_element_t
<
0
,
SrcDescs
>>::
GetNumOfDimension
();
static
constexpr
index_t
mod_num
=
ThreadClusterLengths
{}.
At
(
Number
<
3
>
{});
// Dirty HACK FELIX, TODO fix
static
constexpr
index_t
nSrc
=
remove_cvref_t
<
SrcDescs
>::
Size
();
static
constexpr
index_t
nDst
=
remove_cvref_t
<
DstDescs
>::
Size
();
...
...
@@ -101,7 +102,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()));
make_multi_index
(
ThreadGroup
::
GetThreadId
()
%
mod_num
));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
View file @
8a5bb9f3
...
...
@@ -9,7 +9,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1
_mod8
.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...
...
@@ -1109,12 +1109,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
{
ignore
=
b_element_op
;
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
NumTokens
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bpreshuffled
=
MakeBGridDescriptor_Preshuffled
(
problem
.
BN0Shuffled
,
problem
.
BK0Shuffled
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
<
CLayout
>
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
problem
.
NumTokens
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
// printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(),
// problem.MBlock, problem.NBlock, MPerBlock, NPerBlock);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
...
...
@@ -1125,19 +1125,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
const
index_t
block_m_id
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
);
const
index_t
expert_id
=
__builtin_amdgcn_readfirstlane
(
p_sorted_expert_ids
[
block_m_id
]);
// constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr
auto
AMThreads
=
ABlockTransferThreadClusterLengths_AK0_M_AK1
{}.
At
(
I1
);
constexpr
auto
AK0Threads
=
ABlockTransferThreadClusterLengths_AK0_M_AK1
{}.
At
(
I0
);
constexpr
auto
AK1Threads
=
ABlockTransferThreadClusterLengths_AK0_M_AK1
{}.
At
(
I2
);
constexpr
auto
AKThreads
=
AK0Threads
*
AK1Threads
;
constexpr
auto
AMRepeats
=
MPerBlock
/
AMThreads
;
// static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
const
index_t
token_pos
=
block_m_id
*
MPerBlock
+
threadIdx
.
x
/
AKThreads
*
AMRepeats
;
StaticallyIndexedArray
<
index_t
,
AMRepeats
>
gather_offsets
;
//= p_sorted_token_ids[token_pos];
static_for
<
0
,
AMRepeats
,
1
>
{}([
&
](
auto
m0
)
{
gather_offsets
(
m0
)
=
(
p_sorted_token_ids
[
token_pos
+
m0
]
&
0xffffff
)
*
problem
.
K
;
// printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
});
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_m_id
*
MPerBlock
);
const
index_t
expert_stride
=
__builtin_amdgcn_readfirstlane
(
problem
.
N
*
problem
.
K
);
...
...
@@ -1154,10 +1141,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
// printf("tid %d eid %d expert_stride %d bufsize %d\n",
// threadIdx.x, expert_id, expert_stride, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
...
...
@@ -1166,7 +1149,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// A matrix blockwise copy
auto
a_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
_mod8
<
ThisThreadBlock
,
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
...
...
@@ -1187,15 +1170,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
,
1
,
BlockwiseGemmPipe
::
GlobalBufferNum
>
(
a_grid_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
a_element_op
,
a_block_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
gather_offsets
);
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// Thread-wise copy
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
...
...
@@ -1406,10 +1387,20 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
const
auto
e_grid_desc_mblock_mperblock_nblock_nperblock
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
using
CDEBlockTransferCluster
Lengths_MBlock_MPerBlock_NBlock_NPerBlock
=
using
CDEBlockTransferCluster
=
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
;
const
auto
EGlobalMemoryDataOperation
=
CGlobalMemoryDataOperation
;
constexpr
auto
EMThreads
=
CDEBlockTransferCluster
{}.
At
(
I0
)
*
CDEBlockTransferCluster
{}.
At
(
I1
);
constexpr
auto
EMRepeats
=
MPerBlock
/
EMThreads
;
static_assert
(
EMRepeats
==
1
,
"only support 1 line per thread now!"
);
const
index_t
token_pos
=
block_m_id
*
MPerBlock
+
threadIdx
.
x
/
EMThreads
*
EMRepeats
;
StaticallyIndexedArray
<
index_t
,
EMRepeats
>
scatter_offsets
;
//= p_sorted_token_ids[token_pos];
static_for
<
0
,
EMRepeats
,
1
>
{}([
&
](
auto
m0
)
{
scatter_offsets
(
m0
)
=
(
p_sorted_token_ids
[
token_pos
+
m0
]
&
0xffffff
)
*
problem
.
N
;
// printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
});
auto
cde_block_copy_lds_and_global
=
ThreadGroupTensorSliceTransfer_v7r3
<
ThisThreadBlock
,
decltype
(
container_concat
(
make_tuple
(
CShuffleDataType
{}),
DsDataType
{})),
...
...
@@ -1423,7 +1414,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CDEBlockTransferCluster
Lengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferCluster
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename SrcDimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DstDimAccessOrder,
...
...
@@ -1439,9 +1430,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
{
c_ds_desc_refs
,
idx_c_ds_block_begin
,
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
make_tuple
(
make_multi_index
(
block_m_id
,
0
,
block_n_id
,
0
)),
make_tuple
(
make_multi_index
(
0
,
0
,
block_n_id
,
0
)),
c_element_op
};
// if(threadIdx.x== 0)
// printf("offset %d size %d\n", scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
+
scatter_offsets
(
I0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
()
-
scatter_offsets
(
I0
));
// space filling curve for threadwise C in VGPR
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp
0 → 100644
View file @
8a5bb9f3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
host
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
ReferenceMoeGemm2
:
public
device
::
BaseOperator
{
// Argument
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
Tensor
<
ck
::
index_t
>&
sorted_token_ids
,
const
Tensor
<
ck
::
index_t
>&
expert_ids
,
const
index_t
sorted_tile_size
,
const
Tensor
<
ADataType
>&
a_m_k
,
const
Tensor
<
BDataType
>&
b_e_n_k
,
Tensor
<
CDataType
>&
c_t_n
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
:
sorted_token_ids_
{
sorted_token_ids
},
expert_ids_
{
expert_ids
},
sorted_tile_size_
{
sorted_tile_size
},
a_m_k_
{
a_m_k
},
b_e_n_k_
{
b_e_n_k
},
c_t_n_
{
c_t_n
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
{
}
const
Tensor
<
ck
::
index_t
>&
expert_ids_
;
const
Tensor
<
ck
::
index_t
>&
sorted_token_ids_
;
const
Tensor
<
ADataType
>&
a_m_k_
;
const
Tensor
<
BDataType
>&
b_e_n_k_
;
Tensor
<
CDataType
>&
c_t_n_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
index_t
sorted_tile_size_
;
};
// Invoker
struct
Invoker
:
public
device
::
BaseInvoker
{
using
Argument
=
ReferenceMoeGemm2
::
Argument
;
float
Run
(
const
Argument
&
arg
)
{
arg
.
c_t_n_
.
SetZero
();
auto
f_mk_kn_mn
=
[
&
](
auto
m
,
auto
n
)
{
const
int
K
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
1
];
AccDataType
v_acc
{
0
};
ComputeTypeA
v_a
{
0
};
ComputeTypeB
v_b
{
0
};
const
int
t
=
arg
.
sorted_token_ids_
(
m
);
const
int
e
=
arg
.
expert_ids_
(
m
/
arg
.
sorted_tile_size_
);
const
int
token_cnt
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
0
];
if
(
t
<
token_cnt
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
// use PassThrough instead of ConvertBF16RTN for reference calculation
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
}
else
{
arg
.
a_element_op_
(
v_a
,
arg
.
a_m_k_
(
m
,
k
));
}
// same for B matrix
if
constexpr
(
is_same_v
<
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
ConvertBF16RTN
>
)
{
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}(
v_b
,
arg
.
b_e_n_k_
(
e
,
n
,
k
));
}
else
{
arg
.
b_element_op_
(
v_b
,
arg
.
b_e_n_k_
(
e
,
n
,
k
));
}
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
}
}
CDataType
v_c
{
0
};
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_t_n_
(
t
,
n
)
+=
v_c
;
};
make_ParallelTensorFunctor
(
f_mk_kn_mn
,
arg
.
c_t_n_
.
mDesc
.
GetLengths
()[
0
],
arg
.
c_t_n_
.
mDesc
.
GetLengths
()[
1
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
)
override
{
return
true
;
}
static
auto
MakeArgument
(
const
Tensor
<
ck
::
index_t
>&
sorted_token_ids
,
const
Tensor
<
ck
::
index_t
>&
expert_ids
,
const
index_t
sorted_tile_size
,
const
Tensor
<
ADataType
>&
a_m_k
,
const
Tensor
<
BDataType
>&
b_e_n_k
,
Tensor
<
CDataType
>&
c_t_n
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
{
return
Argument
{
sorted_token_ids
,
expert_ids
,
sorted_tile_size
,
a_m_k
,
b_e_n_k
,
c_t_n
,
a_element_op
,
b_element_op
,
c_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
virtual
std
::
unique_ptr
<
device
::
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"ReferenceMoeGemm2"
<<
std
::
endl
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace host
}
// namespace tensor_operation
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment