Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7ba5bff4
Commit
7ba5bff4
authored
Feb 08, 2025
by
coderfeli
Browse files
one tile ok
parent
8a5bb9f3
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
30 additions
and
31 deletions
+30
-31
example/65_gemm_multiply_multiply/moe_gemm2.cpp
example/65_gemm_multiply_multiply/moe_gemm2.cpp
+11
-9
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp
+6
-7
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
...id/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
+4
-6
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp
...ry/reference_tensor_operation/cpu/reference_moe_gemm2.hpp
+8
-8
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+1
-1
No files found.
example/65_gemm_multiply_multiply/moe_gemm2.cpp
View file @
7ba5bff4
...
...
@@ -119,6 +119,7 @@ using CDEElementOp = MultiplyMultiply;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
ck
::
index_t
MPerBlock
=
32
;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
static
constexpr
ck
::
index_t
KPerBlock
=
256
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
MXDLPerWave
=
MPerBlock
/
32
;
//todo fix this constraint
static
constexpr
ck
::
index_t
AK1
=
16
/
sizeof
(
A0DataType
);
...
...
@@ -142,7 +143,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// ak1, bk1
AK1
,
BK1
,
// mn_perxdl
32
,
32
,
MNPerXDL
,
MNPerXDL
,
// mn_xdlperwave
MXDLPerWave
,
1
,
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
...
...
@@ -173,11 +174,11 @@ int main(int argc, char* argv[])
// GEMM shape
ck
::
index_t
N
=
128
;
ck
::
index_t
K
=
1024
;
ck
::
index_t
experts
=
1
;
ck
::
index_t
sorted_tile_num
=
1
;
ck
::
index_t
experts
=
8
;
ck
::
index_t
sorted_tile_num
=
2
;
ck
::
index_t
sorted_tile_size
=
MPerBlock
;
ck
::
index_t
SORTED_SIZE
=
sorted_tile_num
*
sorted_tile_size
;
ck
::
index_t
tokens
=
1
;
ck
::
index_t
tokens
=
32
;
if
(
argc
==
1
)
{
...
...
@@ -251,7 +252,7 @@ int main(int argc, char* argv[])
Tensor
<
D1DataType
>
d1_t_n
(
f_host_tensor_descriptor
(
tokens
,
N
,
StrideD
,
D1Layout
{}));
Tensor
<
EDataType
>
e_t_n_host_result
(
HostTensorDescriptor
({
tokens
,
N
},
{
N
,
1
}));
Tensor
<
EDataType
>
e_t_n_device_result
(
HostTensorDescriptor
({
tokens
,
N
},
{
N
,
1
}));
e_t_n_device_result
.
SetZero
();
std
::
cout
<<
"a0_m_k: "
<<
a0_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b0_e_n_k: "
<<
b0_e_n_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d1_t_n: "
<<
d1_t_n
.
mDesc
<<
std
::
endl
;
...
...
@@ -358,8 +359,7 @@ int main(int argc, char* argv[])
{
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
,
0
,
0
,
1
});
e_device_buf
.
FromDevice
(
e_t_n_device_result
.
mData
.
data
());
// e_device_buf.FromDevice(e_t_n_device_result.mData.data());
Tensor
<
CShuffleDataType
>
c_t_n
({
tokens
,
N
});
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceMoeGemm2
<
A0DataType
,
...
...
@@ -376,10 +376,11 @@ int main(int argc, char* argv[])
sorted_token_ids
,
expert_ids
,
sorted_tile_size
,
a0_m_k
,
b0_e_n_k
,
c_t_n
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
ref_invoker
.
Run
(
ref_argument
);
for
(
int
m
=
0
;
m
<
SORTED_SIZE
;
++
m
)
for
(
int
t
=
0
;
t
<
tokens
;
++
t
)
{
const
int
t
=
sorted_token_ids
(
m
);
//
const int t = sorted_token_ids(m);
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
cde_element_op
(
e_t_n_host_result
(
t
,
n
),
c_t_n
(
t
,
n
),
d0_t_n
(
t
,
n
),
d1_t_n
(
t
,
n
));
...
...
@@ -389,6 +390,7 @@ int main(int argc, char* argv[])
e_device_buf
.
FromDevice
(
e_t_n_device_result
.
mData
.
data
());
e_t_n_device_result
.
savetxt
(
"out.txt"
);
e_t_n_host_result
.
savetxt
(
"ref.txt"
);
return
ck
::
utils
::
check_err
(
e_t_n_device_result
,
e_t_n_host_result
,
"Error: Incorrect results!"
,
1e-3
,
5e-2
)
?
0
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp
View file @
7ba5bff4
...
...
@@ -101,17 +101,16 @@ struct ThreadGroupTensorSliceTransfer_v7r3
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()
%
mod_num
));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
const
auto
src_thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()));
const
auto
src_thread_slice_origins
=
generate_tuple
(
[
&
](
auto
i
)
{
return
src_block_slice_origins
[
i
]
+
thread_
data_idx_begin
;
},
[
&
](
auto
i
)
{
return
src_block_slice_origins
[
i
]
+
src_
thread_
cluster_idx
*
thread_slice_lengths
;
},
Number
<
nSrc
>
{});
const
auto
dst_thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()
%
mod_num
));
const
auto
dst_thread_slice_origins
=
generate_tuple
(
[
&
](
auto
i
)
{
return
dst_block_slice_origins
[
i
]
+
thread_
data_idx_begin
;
},
[
&
](
auto
i
)
{
return
dst_block_slice_origins
[
i
]
+
dst_
thread_
cluster_idx
*
thread_slice_lengths
;
},
Number
<
nDst
>
{});
threadwise_transfer_
.
SetSrcSliceOrigins
(
src_descs
,
src_thread_slice_origins
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
View file @
7ba5bff4
...
...
@@ -1115,8 +1115,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
MakeBGridDescriptor_Preshuffled
(
problem
.
BN0Shuffled
,
problem
.
BK0Shuffled
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
<
CLayout
>
(
problem
.
NumTokens
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
// printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(),
// problem.MBlock, problem.NBlock, MPerBlock, NPerBlock);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
...
...
@@ -1393,14 +1391,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
constexpr
auto
EMThreads
=
CDEBlockTransferCluster
{}.
At
(
I0
)
*
CDEBlockTransferCluster
{}.
At
(
I1
);
constexpr
auto
EMRepeats
=
MPerBlock
/
EMThreads
;
constexpr
auto
ENThreads
=
CDEBlockTransferCluster
{}.
At
(
I2
)
*
CDEBlockTransferCluster
{}.
At
(
I3
);
static_assert
(
EMRepeats
==
1
,
"only support 1 line per thread now!"
);
const
index_t
token_pos
=
block_m_id
*
MPerBlock
+
threadIdx
.
x
/
E
M
Threads
*
EMRepeats
;
const
index_t
token_pos
=
block_m_id
*
MPerBlock
+
threadIdx
.
x
/
E
N
Threads
*
EMRepeats
;
StaticallyIndexedArray
<
index_t
,
EMRepeats
>
scatter_offsets
;
//= p_sorted_token_ids[token_pos];
static_for
<
0
,
EMRepeats
,
1
>
{}([
&
](
auto
m0
)
{
scatter_offsets
(
m0
)
=
(
p_sorted_token_ids
[
token_pos
+
m0
]
&
0xffffff
)
*
problem
.
N
;
//
printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
printf
(
"init off tid %d m %d off %d
\n
"
,
threadIdx
.
x
,
m0
(),
gather_offsets
(
m0
));
});
// printf("tid %d pos %d offset %d size %d\n", threadIdx.x, token_pos, scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto
cde_block_copy_lds_and_global
=
ThreadGroupTensorSliceTransfer_v7r3
<
ThisThreadBlock
,
decltype
(
container_concat
(
make_tuple
(
CShuffleDataType
{}),
DsDataType
{})),
...
...
@@ -1433,7 +1433,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
make_tuple
(
make_multi_index
(
0
,
0
,
block_n_id
,
0
)),
c_element_op
};
// if(threadIdx.x== 0)
// printf("offset %d size %d\n", scatter_offsets(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
+
scatter_offsets
(
I0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
()
-
scatter_offsets
(
I0
));
// space filling curve for threadwise C in VGPR
...
...
@@ -1461,7 +1460,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>>
{};
static_assert
(
num_access
==
sfc_cde_block
.
GetNumOfAccess
(),
"wrong!"
);
// printf("eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee\n");
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
// make sure it's safe to write to LDS
block_sync_lds
();
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp
View file @
7ba5bff4
...
...
@@ -71,13 +71,13 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
arg
.
c_t_n_
.
SetZero
();
auto
f_mk_kn_mn
=
[
&
](
auto
m
,
auto
n
)
{
const
int
K
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
1
];
AccDataType
v_acc
{
0
};
ComputeTypeA
v_a
{
0
};
ComputeTypeB
v_b
{
0
};
const
int
t
=
arg
.
sorted_token_ids_
(
m
);
const
int
e
=
arg
.
expert_ids_
(
m
/
arg
.
sorted_tile_size_
);
const
int
token_cnt
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
0
];
const
int
token_cnt
=
arg
.
c_t_n_
.
mDesc
.
GetLengths
()[
0
];
if
(
t
<
token_cnt
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
...
...
@@ -105,17 +105,17 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
}
}
CDataType
v_c
{
0
};
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_t_n_
(
t
,
n
)
+=
v_c
;
}
};
make_ParallelTensorFunctor
(
f_mk_kn_mn
,
arg
.
c_t_n
_
.
mDesc
.
GetLengths
()[
0
],
arg
.
c_t_n_
.
mDesc
.
GetLengths
()[
1
])(
std
::
thread
::
hardware_concurrency
()
);
f_mk_kn_mn
,
arg
.
a_m_k
_
.
mDesc
.
GetLengths
()[
0
],
arg
.
c_t_n_
.
mDesc
.
GetLengths
()[
1
])(
1
);
return
0
;
}
...
...
script/cmake-ck-dev.sh
View file @
7ba5bff4
...
...
@@ -17,7 +17,7 @@ fi
cmake
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_FLAGS
=
"-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O
3
--save-temps -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker"
\
-D
CMAKE_CXX_FLAGS
=
"-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O
1 -g
--save-temps -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
ON
\
-D
GPU_TARGETS
=
$GPU_TARGETS
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment