gaoqiong / composable_kernel_ROCM · Commits

Commit 66d08ea3, authored Feb 11, 2025 by coderfeli
Parent: a8a82e0c

    impl topk weight scatter

Showing 5 changed files with 104 additions and 40 deletions (+104 -40):
    example/65_gemm_multiply_multiply/moe_gemm2.cpp                                            +14 -16
    include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp   +4  -2
    include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp                            +34 -14
    include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp                         +38  -3
    include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp   +14  -5
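A note for readers landing on this commit from outside the repo: in the MoE (mixture-of-experts) second GEMM, rows of the expert output are scattered back to their owning tokens, and each row must be scaled by the token's top-k routing weight. This commit threads a per-row weight array (`scatter_weights`) from the example, through the gridwise kernel and the block/thread transfer, so the epilogue applies the weight during the scatter store. A minimal host-side sketch of the intended semantics, with hypothetical names (this is not the CK API):

```cpp
#include <cstdint>
#include <vector>

// Reference semantics of a topk-weighted scatter epilogue. Each sorted row r
// of the expert GEMM output belongs to token (sorted_token_ids[r] & 0xffffff)
// and is accumulated into that token's output row, scaled by its routing
// weight (the device kernel does this with AtomicAdd stores).
void topk_weight_scatter(const std::vector<float>& gemm_out,                // [sorted_rows x N]
                         const std::vector<std::int32_t>& sorted_token_ids, // low 24 bits = token id
                         const std::vector<float>& sorted_weights,          // one weight per sorted row
                         std::vector<float>& out,                           // [tokens x N], zero-initialized
                         int sorted_rows,
                         int N)
{
    for(int r = 0; r < sorted_rows; ++r)
    {
        const std::int32_t token = sorted_token_ids[r] & 0xffffff; // same mask as the kernel
        const float w            = sorted_weights[r];
        for(int n = 0; n < N; ++n)
            out[token * N + n] += w * gemm_out[r * N + n];
    }
}
```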
example/65_gemm_multiply_multiply/moe_gemm2.cpp (View file @ 66d08ea3)
@@ -58,7 +58,7 @@ struct MulABScaleExpertWeight
     template <typename E, typename C, typename D0, typename D1, typename D2>
     __host__ __device__ constexpr void
     operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
-    // gpu
+    // real kernel use
     template <>
     __host__ __device__ constexpr void
     operator()<EDataType, float, float, float, float>(EDataType& e,
@@ -211,8 +211,9 @@ int main(int argc, char* argv[])
     ck::index_t StrideA = K;
     ck::index_t StrideB = K;
-    // ck::index_t StrideD = 0;
     ck::index_t StrideE = N;
+    constexpr ck::index_t NumDTensor = DsDataType::Size();
+    constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
     ck::index_t KBatch = 1;
@@ -238,9 +239,9 @@ int main(int argc, char* argv[])
     Tensor<A0DataType> a0_m_k(HostTensorDescriptor({SORTED_SIZE, K}, {K, 1}));
     Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N * K, K, 1}));
     Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N * K, K, 1}));
-    Tensor<D0DataType> d0_t_n(HostTensorDescriptor({SORTED_SIZE, N}, {0, 0}));
-    Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {0, 0}));
-    Tensor<D2DataType> d2_e_n(HostTensorDescriptor({experts, 1}, {1, 0}));
+    Tensor<D0DataType> d0_m_n(HostTensorDescriptor({SORTED_SIZE, N}, {StrideDs[0], 0}));
+    Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
+    Tensor<D2DataType> d2_e_n(HostTensorDescriptor({experts, N}, {1, 0}));
     Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
     Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
     e_t_n_device_result.SetZero();
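The replacement descriptors above rely on stride-0 dimensions: `HostTensorDescriptor` takes per-dimension lengths and strides, and a stride of 0 maps every index along that dimension to the same element, i.e. a broadcast (so `d2_e_n({experts, N}, {1, 0})` stores one value per expert and reuses it across N). A standalone sketch of that addressing rule, in plain C++ rather than the CK class:

```cpp
#include <array>
#include <cstddef>
#include <iostream>

// Linear offset of a 2-D index under explicit strides; a stride of 0
// broadcasts that dimension (all indices map to the same data).
std::size_t offset_2d(std::array<std::size_t, 2> idx, std::array<std::size_t, 2> stride)
{
    return idx[0] * stride[0] + idx[1] * stride[1];
}

int main()
{
    // Like d2_e_n({experts, N}, {1, 0}): one scale per expert, broadcast over N.
    std::cout << offset_2d({3, 0}, {1, 0}) << '\n'; // 3
    std::cout << offset_2d({3, 7}, {1, 0}) << '\n'; // 3, same element for any n
}
```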
@@ -248,7 +249,7 @@ int main(int argc, char* argv[])
     std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
     std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
     std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
-    std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
+    std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
     std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;

     switch(init_method)
@@ -257,21 +258,21 @@ int main(int argc, char* argv[])
     case 1:
         a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
         b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
-        d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
+        d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
         d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
         d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
         break;
     case 2:
         a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
         b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
-        d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
+        d0_m_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
         d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
         d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
         break;
     default:
         a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
         b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
-        d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
+        d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
         d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
         d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
     }
@@ -279,7 +280,7 @@ int main(int argc, char* argv[])
     DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
     DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
     DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
-    DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
+    DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
     DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
     DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
     DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
@@ -287,7 +288,7 @@ int main(int argc, char* argv[])
     sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
     expert_ids_dev.ToDevice(expert_ids.mData.data());
     a0_device_buf.ToDevice(a0_m_k.mData.data());
-    d0_device_buf.ToDevice(d0_t_n.mData.data());
+    d0_device_buf.ToDevice(d0_m_n.mData.data());
     d1_device_buf.ToDevice(d1_e_n.mData.data());
     d2_device_buf.ToDevice(d2_e_n.mData.data());
     e_device_buf.ToDevice(e_t_n_device_result.mData.data());
@@ -296,9 +297,6 @@ int main(int argc, char* argv[])
     auto b_element_op   = BElementOp{};
     auto cde_element_op = CDEElementOp{};

-    constexpr ck::index_t NumDTensor = DsDataType::Size();
-    constexpr auto I0 = ck::Number<0>{};
-
     // do GEMM
     auto device_op = DeviceOpInstance{};
@@ -325,7 +323,7 @@ int main(int argc, char* argv[])
...
@@ -325,7 +323,7 @@ int main(int argc, char* argv[])
K
,
K
,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
std
::
array
<
ck
::
index_t
,
NumDTensor
>
{
I0
,
I0
,
I0
}
,
StrideDs
,
StrideE
,
StrideE
,
KBatch
,
KBatch
,
a_element_op
,
a_element_op
,
@@ -375,7 +373,7 @@ int main(int argc, char* argv[])
     auto ref_moe_gemm = ReferenceGemmInstance{};
     auto ref_invoker  = ref_moe_gemm.MakeInvoker();
     auto ref_argument = ref_moe_gemm.MakeArgument(
-        sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, d0_t_n,
+        sorted_token_ids, expert_ids, sorted_tile_size, a0_m_k, b0_e_n_k, d0_m_n,
         d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op);
     ref_invoker.Run(ref_argument);
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp (View file @ 66d08ea3)
@@ -64,13 +64,15 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
         const DstDescs& dst_descs,
         const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
         const ElementwiseOperation& element_op,
-        const StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets)
+        const StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
+        const StaticallyIndexedArray<float, scatter_num>& scatter_weights)
         : threadwise_transfer_(src_descs,
                                StaticallyIndexedArray<Index, nSrc>{},
                                dst_descs,
                                StaticallyIndexedArray<Index, nDst>{},
                                element_op,
-                               scatter_offsets)
+                               scatter_offsets,
+                               scatter_weights)
     {
         static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
                       nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
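The block-level change is pure plumbing: the constructor grows a `scatter_weights` argument and forwards it, unmodified, to the threadwise worker that performs the actual copies. The shape of that delegation, reduced to a sketch with hypothetical names:

```cpp
#include <array>
#include <cstddef>

// Thread-level worker: actually stores the per-scatter-point weights.
template <std::size_t ScatterNum>
struct ThreadwiseTransferSketch
{
    ThreadwiseTransferSketch(const std::array<int, ScatterNum>& offsets,
                             const std::array<float, ScatterNum>& weights)
        : offsets_(offsets), weights_(weights)
    {
    }
    std::array<int, ScatterNum> offsets_;
    std::array<float, ScatterNum> weights_;
};

// Block-level wrapper: grows the same parameter and forwards it, keeping the
// two constructor signatures in lock-step (as this hunk and the threadwise
// hunk later in the commit do).
template <std::size_t ScatterNum>
struct BlockwiseTransferSketch
{
    BlockwiseTransferSketch(const std::array<int, ScatterNum>& offsets,
                            const std::array<float, ScatterNum>& weights)
        : threadwise_transfer_(offsets, weights)
    {
    }
    ThreadwiseTransferSketch<ScatterNum> threadwise_transfer_;
};
```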
include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp (View file @ 66d08ea3)
@@ -305,23 +305,43 @@ struct DeviceMoeGemm
             // {
             //     if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
             //     {
-            //         const auto kernel = kernel_moe_gemm_gather<
-            //             GridwiseGemm,
-            //             true,
-            //             InMemoryDataOperationEnum::AtomicAdd,
-            //             minimum_occupancy,
-            //             TailNumber::Odd>;
-            //         RunKernel(kernel);
+            //         if constexpr (IsGatherGemm) {
+            //             const auto kernel = kernel_moe_gemm_gather<
+            //                 GridwiseGemm,
+            //                 true,
+            //                 InMemoryDataOperationEnum::AtomicAdd,
+            //                 minimum_occupancy,
+            //                 TailNumber::Odd>;
+            //             RunKernel(kernel);
+            //         else {
+            //             const auto kernel = kernel_moe_gemm_scatter<
+            //                 GridwiseGemm,
+            //                 true,
+            //                 InMemoryDataOperationEnum::AtomicAdd,
+            //                 minimum_occupancy,
+            //                 TailNumber::Odd>;
+            //             RunKernel(kernel);
+            //         }
             //     }
             //     else
             //     {
-            //         const auto kernel = kernel_moe_gemm_gather<
-            //             GridwiseGemm,
-            //             true,
-            //             InMemoryDataOperationEnum::AtomicAdd,
-            //             minimum_occupancy,
-            //             TailNumber::Even>;
-            //         RunKernel(kernel);
+            //         if constexpr (IsGatherGemm) {
+            //             const auto kernel = kernel_moe_gemm_gather<
+            //                 GridwiseGemm,
+            //                 true,
+            //                 InMemoryDataOperationEnum::AtomicAdd,
+            //                 minimum_occupancy,
+            //                 TailNumber::Even>;
+            //             RunKernel(kernel);
+            //         else {
+            //             const auto kernel = kernel_moe_gemm_scatter<
+            //                 GridwiseGemm,
+            //                 true,
+            //                 InMemoryDataOperationEnum::AtomicAdd,
+            //                 minimum_occupancy,
+            //                 TailNumber::Even>;
+            //             RunKernel(kernel);
+            //         }
             //     }
             // }
             // else
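The dispatch this (still commented-out) block sketches selects `kernel_moe_gemm_gather` or `kernel_moe_gemm_scatter` from the compile-time `IsGatherGemm` flag, crossed with the Odd/Even tail number of the K loop. The selection pattern itself is ordinary `if constexpr` dispatch; a self-contained sketch with stand-in kernels (not the CK ones):

```cpp
#include <cstdio>

enum class TailNumber { Odd, Even };

template <TailNumber Tail> void kernel_gather_sketch()  { std::puts("gather");  }
template <TailNumber Tail> void kernel_scatter_sketch() { std::puts("scatter"); }

// Compile-time selection on IsGatherGemm, runtime selection on the tail
// number, mirroring the structure of the commented-out dispatch above.
template <bool IsGatherGemm>
void run_sketch(TailNumber tail)
{
    auto launch = [](auto kernel) { kernel(); }; // stands in for RunKernel
    if(tail == TailNumber::Odd)
    {
        if constexpr(IsGatherGemm) { launch(kernel_gather_sketch<TailNumber::Odd>); }
        else                       { launch(kernel_scatter_sketch<TailNumber::Odd>); }
    }
    else
    {
        if constexpr(IsGatherGemm) { launch(kernel_gather_sketch<TailNumber::Even>); }
        else                       { launch(kernel_scatter_sketch<TailNumber::Even>); }
    }
}

int main() { run_sketch<false>(TailNumber::Odd); } // prints "scatter"
```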
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp (View file @ 66d08ea3)
@@ -486,13 +486,36 @@ struct GridwiseMoeGemmScatter
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));
     }

+    template <typename DLayout>
+    __host__ __device__ static auto
+    MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
+    {
+        const auto c_grid_desc_mraw_nraw = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, DLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
+            }
+        }();
+
+        // pad M and N
+        return transform_tensor_descriptor(
+            c_grid_desc_mraw_nraw,
+            make_tuple(make_right_pad_transform(M, MPad - M), make_right_pad_transform(N, NPad - N)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+    }
+
     __host__ __device__ static auto MakeDsGridDescriptor_M_N(
         index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
     {
         return generate_tuple(
             [&](auto i) {
                 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

-                return MakeCGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
+                return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
             },
             Number<NumDTensor>{});
     }
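The new `MakeDGridDescriptor_M_N` gives the free dimension of a D tensor stride `I0` (a compile-time 0), so a row-major D broadcasts along N and a column-major one along M, then right-pads (M, N) to (MPad, NPad) so the grid tiles evenly. The padding is index arithmetic only: padded coordinates stay addressable but fall outside the valid region. A plain-C++ sketch of that mapping, an assumed simplification rather than the CK descriptor machinery:

```cpp
#include <cstdint>
#include <iostream>

struct PaddedDesc2D
{
    std::int64_t M, N;            // raw (valid) extents
    std::int64_t MPad, NPad;      // padded extents, multiples of the block tile
    std::int64_t stride_m, stride_n; // one of these is 0 for a broadcast D tensor

    bool valid(std::int64_t m, std::int64_t n) const { return m < M && n < N; }
    std::int64_t offset(std::int64_t m, std::int64_t n) const { return m * stride_m + n * stride_n; }
};

int main()
{
    // Row-major D scale: strided along M, broadcast along N, padded to tiles.
    PaddedDesc2D d{/*M=*/100, /*N=*/96, /*MPad=*/128, /*NPad=*/128,
                   /*stride_m=*/1, /*stride_n=*/0};
    std::cout << d.offset(7, 42) << " valid=" << d.valid(7, 42) << '\n';   // 7 valid=1
    std::cout << d.offset(120, 0) << " valid=" << d.valid(120, 0) << '\n'; // 120 valid=0 (pad region)
}
```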
@@ -509,7 +532,6 @@ struct GridwiseMoeGemmScatter
             Number<NumDTensor>{});
     }

-    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;

     struct Problem
     {
@@ -1354,6 +1376,14 @@ struct GridwiseMoeGemmScatter
         const auto ds_grid_buf = generate_tuple(
             [&](auto i) {
+                using DDataType    = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+                const DDataType* ptr_ = p_ds_grid[i];
+                // hack logic here to support different kind of strides. todo fix it.
+                // ascale M, 1; bscale E, N, 1, move ptr to E
+                if(i.value == 1)
+                {
+                    ptr_ += expert_id * problem.StrideDs[1] * problem.N;
+                }
                 return make_dynamic_buffer<AddressSpaceEnum::Global>(
                     p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
             },
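The self-described hack computes a per-expert base for the D1 (per-expert, per-N scale) tensor: each expert is assumed to own `StrideDs[1] * N` elements, so expert `e` starts at `e * StrideDs[1] * N`. Note that as committed, the hunk still builds the buffer from `p_ds_grid[i]`, so `ptr_` is computed but not yet consumed here. The arithmetic, as a sketch:

```cpp
#include <cstddef>

// Base offset of expert e's slab in the D1 scale tensor, mirroring
// ptr_ += expert_id * problem.StrideDs[1] * problem.N above.
std::size_t expert_scale_base(std::size_t expert_id, std::size_t stride_ds1, std::size_t n)
{
    return expert_id * stride_ds1 * n; // e.g. expert 3, stride 1, N=4096 -> 12288
}
```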
@@ -1398,8 +1428,12 @@ struct GridwiseMoeGemmScatter
         // static_assert(EMRepeats == 1, "only support 1 line per thread now!");
         const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats;
         StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[token_pos];
+        StaticallyIndexedArray<float, EMRepeats> scatter_weights;   //= for topk
+        // too hack here, 2 specific for topk weights, fixme
+        const float* p_sorted_weights = p_ds_grid[I2];
         static_for<0, EMRepeats, 1>{}([&](auto m0) {
             scatter_offsets(m0) = (p_sorted_token_ids[token_pos + m0] & 0xffffff) * problem.N;
+            scatter_weights(m0) = p_sorted_weights[token_pos + m0];
             // printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
         });
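`scatter_offsets` masks each sorted token id with `0xffffff` before scaling by N, which implies the routing metadata keeps the token id in the low 24 bits of each 32-bit entry. What the high 8 bits hold (for instance, the top-k slot) is not visible in this commit, so treat that part of the sketch below as an assumption:

```cpp
#include <cstdint>

// Assumed packing: low 24 bits = token id, high 8 bits = auxiliary routing
// data (hypothetical; only the 0xffffff unpack appears in the commit).
constexpr std::uint32_t pack_sorted_id(std::uint32_t token_id, std::uint32_t aux)
{
    return (aux << 24) | (token_id & 0xffffff);
}
constexpr std::uint32_t unpack_token(std::uint32_t packed) { return packed & 0xffffff; }
constexpr std::uint32_t unpack_aux(std::uint32_t packed) { return packed >> 24; }

static_assert(unpack_token(pack_sorted_id(123456, 2)) == 123456);
static_assert(unpack_aux(pack_sorted_id(123456, 2)) == 2);
// The kernel then scatters a row to element offset unpack_token(...) * N.
```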
@@ -1435,7 +1469,8 @@ struct GridwiseMoeGemmScatter
             tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
             make_tuple(make_multi_index(0, 0, block_n_id, 0)),
             c_element_op,
-            scatter_offsets};
+            scatter_offsets,
+            scatter_weights};

         // if(threadIdx.x== 0)
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp (View file @ 66d08ea3)
@@ -99,11 +99,13 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
         const DstDescs& dst_descs,
         const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
         const ElementwiseOperation& element_op,
-        const StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets)
+        const StaticallyIndexedArray<index_t, scatter_num>& scatter_offsets,
+        const StaticallyIndexedArray<float, scatter_num>& scatter_weights)
         : src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
           dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
           element_op_(element_op),
-          scatter_offsets_(scatter_offsets)
+          scatter_offsets_(scatter_offsets),
+          scatter_weights_(scatter_weights)
     {
         static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
                       "wrong! cannot evenly divide");
@@ -172,14 +174,20 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
                     src_coords_[i]);
                 oob_val = oob_val & is_src_valid;

-                if constexpr(SrcScalarPerVectors{}[i] == 1)
+                if(i.value == 2)
+                {
+                    static_assert(SrcScalarPerVectors{}[Number<2>{}] == 1,
+                                  "scatter weight dim, should only one vec");
+                    constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
+                    static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
+                        src_vectors(i).template AsType<float>()(j) = scatter_weights_(Number<iScatter>{});
+                    });
+                }
+                else if constexpr(SrcScalarPerVectors{}[i] == 1)
                 {
                     auto data_types = SrcDatas{};
                     using DataType  = remove_cvref_t<decltype(data_types[i])>;
                     const auto tmp =
                         src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
+                    // printf("tid %d srcid %d off %d v %f\n", threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp);
                     static_for<0, SrcScalarPerVector, 1>{}(
                         [&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
                 }
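The new `i.value == 2` branch repurposes source slot 2: instead of loading D2 from global memory, it fills that slot's register vector with the thread's precomputed scatter weight for the row being written, so the existing CDE elementwise multiply applies the top-k weight with no extra pass over the output. The core move, reduced to a sketch:

```cpp
#include <array>
#include <cstddef>

// For each access, source slot 2's vector is overwritten with the scatter
// weight of the row being written, so a downstream elementwise op that
// already multiplies by d2 picks up the topk weight (sketch, not CK code).
template <std::size_t VecLen, std::size_t ScatterNum>
void fill_weight_vector(std::array<float, VecLen>& src_vec2,
                        const std::array<float, ScatterNum>& scatter_weights,
                        std::size_t i_scatter) // row within the thread's scatter window
{
    for(std::size_t j = 0; j < VecLen; ++j)
        src_vec2[j] = scatter_weights[i_scatter]; // broadcast one weight across the vector
}
```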
@@ -691,6 +699,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
     DstCoords dst_coords_;
     const ElementwiseOperation element_op_;
     StaticallyIndexedArray<index_t, scatter_num> scatter_offsets_;
+    StaticallyIndexedArray<float, scatter_num> scatter_weights_;
 };

 } // namespace ck