gaoqiong / composable_kernel_ROCM · Commits

Commit d4b8f1e3
Authored Feb 14, 2025 by coderfeli

add codes for a scatter

Parent: 82e1f1b9
Showing 4 changed files with 55 additions and 32 deletions (+55 -32)
example/65_gemm_multiply_multiply/moe_gemm1.cpp                                     +31 -16
include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp                     +4  -1
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp                   +11 -7
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp    +9  -8
example/65_gemm_multiply_multiply/moe_gemm1.cpp
@@ -195,7 +195,10 @@ int main(int argc, char* argv[])
     ck::index_t sorted_tile_num  = 8;
     ck::index_t sorted_tile_size = MPerBlock;
     ck::index_t SORTED_SIZE      = sorted_tile_num * sorted_tile_size;
-    ck::index_t tokens = 128;
+    ck::index_t batch  = 64;
+    ck::index_t topk   = 2;
+    ck::index_t tokens = batch * topk;
     if(argc == 1)
     {
@@ -241,9 +244,14 @@ int main(int argc, char* argv[])
     for(int i = 0; i < SORTED_SIZE; i++)
     {
         int tile_off = i % sorted_tile_size;
         if(tile_off < token_per_tile)
-            sorted_token_ids.mData[i] = tokenid++;
+        {
+            sorted_token_ids.mData[i] = (tokenid % batch) & ((tokenid / batch) << 24);
+            tokenid++;
+        }
         else
         {
             sorted_token_ids.mData[i] = tokens;
         }
     }
+    expert_ids.savetxt("expert_ids.txt", "int");
+    sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
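Note on the fused ids built in the hunk above: the low 24 bits of each sorted_token_ids entry carry the token index and the high 8 bits carry the top-k slot, which is how the kernel and the reference below decode them (& 0xffffff and (& 0xff000000) >> 24). A minimal, standalone sketch of that layout, assuming the two fields are meant to be combined with a bitwise OR; the pack_token_id/unpack_token_id helpers are illustrative and not part of this commit:

    #include <cassert>
    #include <cstdint>

    // Illustrative helpers for the fused-id layout used by this commit:
    // bits [0, 24) hold the token index, bits [24, 32) hold the top-k slot.
    inline uint32_t pack_token_id(uint32_t token, uint32_t topk_slot)
    {
        assert(token < (1u << 24) && topk_slot < (1u << 8));
        return (topk_slot << 24) | token; // assumed OR-combine of the two fields
    }

    inline void unpack_token_id(uint32_t fused, uint32_t& token, uint32_t& topk_slot)
    {
        token     = fused & 0xffffff;           // same mask as the kernel/reference code
        topk_slot = (fused & 0xff000000) >> 24; // same shift as the kernel/reference code
    }

    int main()
    {
        const uint32_t fused = pack_token_id(/*token=*/37, /*topk_slot=*/1);
        uint32_t token = 0, slot = 0;
        unpack_token_id(fused, token, slot);
        assert(token == 37 && slot == 1);
        return 0;
    }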
@@ -252,14 +260,14 @@ int main(int argc, char* argv[])
     Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N * K, K, 1}));
     Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
     Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
-    Tensor<EDataType> e_m_n_host_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
-    Tensor<EDataType> e_m_n_device_result(HostTensorDescriptor({SORTED_SIZE, N}, {N, 1}));
+    Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
+    Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));

     std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
     std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
     std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
     std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
-    std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
+    std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;

     switch(init_method)
     {
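The device result in the hunk above is now a token-major (tokens, topk, N) tensor with row-major strides {topk * N, N, 1}, so element (t, k, n) lives at flat offset t * topk * N + k * N + n. A small standalone sketch of that mapping (names are illustrative, not part of the commit):

    #include <cassert>
    #include <cstddef>

    // Flat offset of element (t, k, n) in a (tokens, topk, N) tensor with
    // strides {topk * N, N, 1}, matching e_t_n_host_result / e_t_n_device_result.
    inline std::size_t e_t_n_offset(std::size_t t, std::size_t k, std::size_t n,
                                    std::size_t topk, std::size_t N)
    {
        return t * topk * N + k * N + n;
    }

    int main()
    {
        const std::size_t topk = 2, N = 8;
        // Token 3, top-k slot 1, column 5 -> 3*16 + 1*8 + 5 = 61.
        assert(e_t_n_offset(3, 1, 5, topk, N) == 61);
        return 0;
    }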
@@ -290,14 +298,14 @@ int main(int argc, char* argv[])
     DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
     DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
     DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
-    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());

     a0_t_k.savetxt("a.txt");
     sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
     expert_ids_dev.ToDevice(expert_ids.mData.data());
     a0_device_buf.ToDevice(a0_t_k.mData.data());
     d0_device_buf.ToDevice(d0_t_n.mData.data());
     d1_device_buf.ToDevice(d1_e_n.mData.data());
-    e_device_buf.ToDevice(e_m_n_device_result.mData.data());
+    // e_device_buf.ToDevice(e_t_n_device_result.mData.data());

     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
@@ -322,6 +330,7 @@ int main(int argc, char* argv[])
                               d1_device_buf.GetDeviceBuffer()},
                              e_device_buf.GetDeviceBuffer(),
                              tokens,
+                             topk,
                              SORTED_SIZE,
                              N,
                              K,
@@ -359,9 +368,9 @@ int main(int argc, char* argv[])
     {
         invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
-        e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+        e_device_buf.FromDevice(e_t_n_device_result.mData.data());

-        Tensor<CShuffleDataType> c_m_n({SORTED_SIZE, N});
+        Tensor<CShuffleDataType> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});

         using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm<A0DataType,
                                                                                    B0DataType,
@@ -374,25 +383,31 @@ int main(int argc, char* argv[])
         auto ref_invoker = ref_moe_gemm.MakeInvoker();

         auto ref_argument = ref_moe_gemm.MakeArgument(
-            sorted_token_ids, expert_ids, sorted_tile_size, a0_t_k, b0_e_n_k, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
+            sorted_token_ids, expert_ids, sorted_tile_size, a0_t_k, b0_e_n_k, c_t_k_n, PassThrough{}, PassThrough{}, PassThrough{});

         ref_invoker.Run(ref_argument);

         for(int m = 0; m < SORTED_SIZE; ++m)
         {
-            const int t = sorted_token_ids(m);
+            const int fuse_t = sorted_token_ids(m);
+            const int t      = fuse_t & 0xffffff;
             if(t >= tokens)
             {
                 continue;
             }
+            const int topk_id = (fuse_t & 0xff000000) >> 24;
             const int e = expert_ids(m / sorted_tile_size);
             for(int n = 0; n < N; ++n)
             {
-                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(e, n));
+                cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(m, topk_id, n), d0_t_n(t, n), d1_e_n(e, n));
             }
         }

-        e_device_buf.FromDevice(e_m_n_device_result.mData.data());
-        e_m_n_device_result.savetxt("out.txt");
-        e_m_n_host_result.savetxt("ref.txt");
+        e_device_buf.FromDevice(e_t_n_device_result.mData.data());
+        e_t_n_device_result.savetxt("out.txt");
+        e_t_n_host_result.savetxt("ref.txt");

         return ck::utils::check_err(
-                   e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
+                   e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
                    ? 0
                    : 1;
     }
include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp
@@ -505,6 +505,7 @@ struct DeviceMoeGemm
                  std::array<const void*, NumDTensor> p_ds,
                  void* p_c,
                  index_t NumTokens,
+                 index_t TopK,
                  index_t M,
                  index_t N,
                  index_t K,
@@ -524,6 +525,7 @@ struct DeviceMoeGemm
             p_ds,
             static_cast<CDataType*>(p_c),
             NumTokens,
+            TopK,
             M,
             N,
             K,
@@ -563,7 +565,8 @@ struct DeviceMoeGemm
             static_cast<const BDataType*>(p_b),
             p_ds,
             static_cast<CDataType*>(p_c),
-            M,
+            M, //randoms set, no use
+            0,
             M,
             N,
             K,
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
@@ -545,6 +545,7 @@ struct GridwiseMoeGemmGather
               index_t KBatch_)
         : NumTokens{NumTokens_},
+          TopK{TopK_},
           M{M_},
           N{N_},
           K{K_},
@@ -570,6 +571,7 @@ struct GridwiseMoeGemmGather
         {
             std::cout << "problem {" << "NumTokens:" << NumTokens << ", "
+                      << "TopK:" << TopK << ", "
                       << "M:" << M << ", "
                       << "N:" << N << ", "
                       << "K:" << K << ", "
@@ -587,6 +589,7 @@ struct GridwiseMoeGemmGather
         }

         index_t NumTokens;
+        index_t TopK;
         index_t M;
         index_t N;
         index_t K;
@@ -619,6 +622,7 @@ struct GridwiseMoeGemmGather
                  std::array<const void*, NumDTensor> p_ds_grid_,
                  CDataType* p_c_grid_,
                  index_t NumTokens_,
+                 index_t TopK_,
                  index_t M_,
                  index_t N_,
                  index_t K_,
@@ -630,7 +634,7 @@ struct GridwiseMoeGemmGather
                  AElementwiseOperation a_element_op_,
                  BElementwiseOperation b_element_op_,
                  CElementwiseOperation c_element_op_)
-            : Problem{NumTokens_, M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
+            : Problem{NumTokens_, TopK_, M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
               p_sorted_token_ids{p_sorted_token_ids_},
               p_sorted_expert_ids{p_sorted_expert_ids_},
@@ -1155,10 +1159,10 @@ struct GridwiseMoeGemmGather
         // static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
         const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
-        const index_t t0        = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
-        if(t0 >= problem.NumTokens)
+        const index_t t0        = p_sorted_token_ids[block_m_id * MPerBlock];
+        if((t0 & 0xffffff) >= problem.NumTokens)
             return;
+        const index_t topk_id = (t0 & 0xff000000) >> 24;
         StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
         static_for<0, AMRepeats, 1>{}([&](auto m0) {
             gather_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.K;
@@ -1450,7 +1454,7 @@ struct GridwiseMoeGemmGather
         // too hack here, 2 specific for topk weights, fixme
         const float* p_sorted_weights = p_ds_grid[I0];
         static_for<0, EMRepeats, 1>{}([&](auto m0) {
-            scatter_offsets(m0) = 0;
+            scatter_offsets(m0) = ((p_sorted_token_ids[c_token_pos + m0] & 0xffffff) * problem.TopK + topk_id) * problem.N;
             scatter_weights(m0) = p_sorted_weights[(c_token_pos + m0) * problem.StrideDs[0]];
             // if(threadIdx.x % 16 == 0)
             // printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
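The scatter offsets computed in the hunk above send each row of the sorted, tile-padded GEMM output to its token-major destination: the fused id is decoded to a token index, the destination row is token * TopK + topk_id, and the flat offset is that row times N; the per-row weight is read through StrideDs[0]. A host-side sketch of the same scatter, assuming the weight is applied as a multiplicative scale (consistent with the multiply epilogue in this example); illustrative only, not the device epilogue itself:

    #include <cstdint>
    #include <vector>

    // Host-side illustration of the scatter performed by the device epilogue:
    // each sorted row m goes to row (token * topk + topk_id) of the token-major
    // output, scaled by its sorted weight. All names here are illustrative.
    void scatter_rows(const std::vector<float>& c_sorted,            // [sorted_size x N]
                      const std::vector<uint32_t>& sorted_token_ids, // fused ids, one per row
                      const std::vector<float>& sorted_weights,      // one weight per sorted row
                      std::vector<float>& e_t_n,                     // [tokens * topk x N]
                      int sorted_size, int tokens, int topk, int N)
    {
        for(int m = 0; m < sorted_size; ++m)
        {
            const uint32_t fused = sorted_token_ids[m];
            const int token      = fused & 0xffffff;           // low 24 bits: token index
            const int topk_id    = (fused & 0xff000000) >> 24; // high 8 bits: top-k slot
            if(token >= tokens)
                continue; // padded row, nothing to scatter
            const int dst_row = token * topk + topk_id;
            for(int n = 0; n < N; ++n)
                e_t_n[dst_row * N + n] = sorted_weights[m] * c_sorted[m * N + n];
        }
    }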
@@ -1482,13 +1486,13 @@ struct GridwiseMoeGemmGather
                 false>>,           // ThreadTransferSrcResetCoordinateAfterRunFlags
                 Sequence<false>,   // ThreadTransferDstResetCoordinateAfterRunFlags
                 1,                 // ScatterDim
-                false,             // OutputScatter: false, only use scatter weights
+                true,              // OutputScatter: false, only use scatter weights
                 1                  // ScatterWeightIdx: ascale
                 >{c_ds_desc_refs,
                   idx_c_ds_block_begin,
                   tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-                  make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
+                  make_tuple(make_multi_index(0, 0, block_n_id, 0)),
                   c_element_op,
                   scatter_offsets,
                   scatter_weights};
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp
@@ -33,7 +33,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
                  const index_t sorted_tile_size,
                  const Tensor<ADataType>& a_t_k,
                  const Tensor<BDataType>& b_e_n_k,
-                 Tensor<CDataType>& c_m_n,
+                 Tensor<CDataType>& c_t_k_n,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op)
@@ -42,7 +42,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
               sorted_tile_size_{sorted_tile_size},
               a_t_k_{a_t_k},
               b_e_n_k_{b_e_n_k},
-              c_m_n_{c_m_n},
+              c_t_k_n_{c_t_k_n},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               c_element_op_{c_element_op}
@@ -54,7 +54,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
         index_t sorted_tile_size_;
         const Tensor<ADataType>& a_t_k_;
         const Tensor<BDataType>& b_e_n_k_;
-        Tensor<CDataType>& c_m_n_;
+        Tensor<CDataType>& c_t_k_n_;
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
@@ -74,7 +74,8 @@ struct ReferenceMoeGemm : public device::BaseOperator
                 AccDataType v_acc{0};
                 ComputeTypeA v_a{0};
                 ComputeTypeB v_b{0};
-                const int t = arg.sorted_token_ids_(m);
+                const int t       = arg.sorted_token_ids_(m) & 0xffffff;
+                const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24;
                 const int e         = arg.expert_ids_(m / arg.sorted_tile_size_);
                 const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0];
                 if(t < token_cnt)
                 {
@@ -110,11 +111,11 @@ struct ReferenceMoeGemm : public device::BaseOperator
                     arg.c_element_op_(v_c, v_acc);
-                    arg.c_m_n_(m, n) = v_c;
+                    arg.c_t_k_n_(t, topk_id, n) = v_c;
                 };

-            make_ParallelTensorFunctor(f_mk_kn_mn,
-                                       arg.c_m_n_.mDesc.GetLengths()[0],
-                                       arg.c_m_n_.mDesc.GetLengths()[1])(
+            make_ParallelTensorFunctor(f_mk_kn_mn,
+                                       arg.sorted_tile_size_,
+                                       arg.c_t_k_n_.mDesc.GetLengths()[2])(
                 std::thread::hardware_concurrency());

             return 0;
@@ -140,12 +141,12 @@ struct ReferenceMoeGemm : public device::BaseOperator
                      const index_t sorted_tile_size,
                      const Tensor<ADataType>& a_t_k,
                      const Tensor<BDataType>& b_e_n_k,
-                     Tensor<CDataType>& c_m_n,
+                     Tensor<CDataType>& c_t_k_n,
                      AElementwiseOperation a_element_op,
                      BElementwiseOperation b_element_op,
                      CElementwiseOperation c_element_op)
     {
-        return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_t_k, b_e_n_k, c_m_n, a_element_op, b_element_op, c_element_op};
+        return Argument{sorted_token_ids, expert_ids, sorted_tile_size, a_t_k, b_e_n_k, c_t_k_n, a_element_op, b_element_op, c_element_op};
     }

     static auto MakeInvoker() { return Invoker{}; }