Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b02c0b82
Commit
b02c0b82
authored
Feb 11, 2025
by
coderfeli
Browse files
gemm1 scale debug
parent
e4ca61f9
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
107 additions
and
51 deletions
+107
-51
example/65_gemm_multiply_multiply/moe_gemm1.cpp
example/65_gemm_multiply_multiply/moe_gemm1.cpp
+22
-18
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp
...block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp
+6
-2
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
...ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
+62
-16
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp
...k/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp
+4
-6
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp
.../thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp
+11
-7
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp
...ary/reference_tensor_operation/cpu/reference_moe_gemm.hpp
+2
-2
No files found.
example/65_gemm_multiply_multiply/moe_gemm1.cpp
View file @
b02c0b82
...
@@ -35,7 +35,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -35,7 +35,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using
A0DataType
=
F8
;
using
A0DataType
=
F8
;
using
B0DataType
=
F8
;
using
B0DataType
=
F8
;
using
EDataType
=
F
16
;
using
EDataType
=
F
32
;
using
AccDataType
=
F32
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
CShuffleDataType
=
F32
;
using
D0DataType
=
F32
;
using
D0DataType
=
F32
;
...
@@ -68,9 +68,11 @@ struct MulABScale
...
@@ -68,9 +68,11 @@ struct MulABScale
const
float
&
d1
,
const
float
&
d1
,
const
D2DataType
&
d2
)
const
const
D2DataType
&
d2
)
const
{
{
// const float x0_f = c * d0 * d1;
(
void
)
d2
;
// for gate, no d2 needed
(
void
)
d0
;
(
void
)
d1
;
(
void
)
d2
;
(
void
)
d0
;
const
float
x0_f
=
c
;
(
void
)
d1
;
const
float
x0_f
=
c
;
// const float x0_f = c;
e
=
ck
::
type_convert
<
EDataType
>
(
x0_f
);
e
=
ck
::
type_convert
<
EDataType
>
(
x0_f
);
}
}
};
};
...
@@ -91,8 +93,10 @@ struct MulABScaleSiluMulGate
...
@@ -91,8 +93,10 @@ struct MulABScaleSiluMulGate
const
D2DataType
&
d2
)
const
const
D2DataType
&
d2
)
const
{
{
// act
// act
(
void
)
d0
;
(
void
)
d1
;
(
void
)
d2
;
(
void
)
d2
;
float
x0
;
float
x0
=
0
;
ck
::
tensor_operation
::
element_wise
::
Silu
{}(
x0
,
c
*
d1
*
d0
);
ck
::
tensor_operation
::
element_wise
::
Silu
{}(
x0
,
c
*
d1
*
d0
);
// fuse mul
// fuse mul
e
=
ck
::
type_convert
<
EDataType
>
(
x0
);
e
=
ck
::
type_convert
<
EDataType
>
(
x0
);
...
@@ -145,7 +149,7 @@ using AElementOp = PassThrough;
...
@@ -145,7 +149,7 @@ using AElementOp = PassThrough;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
ck
::
index_t
MPerBlock
=
128
;
static
constexpr
ck
::
index_t
MPerBlock
=
32
;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
static
constexpr
ck
::
index_t
KPerBlock
=
256
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
KPerBlock
=
256
/
sizeof
(
A0DataType
);
static
constexpr
ck
::
index_t
MXDLPerWave
=
MPerBlock
/
32
;
//todo fix this constraint
static
constexpr
ck
::
index_t
MXDLPerWave
=
MPerBlock
/
32
;
//todo fix this constraint
...
@@ -208,7 +212,7 @@ int main(int argc, char* argv[])
...
@@ -208,7 +212,7 @@ int main(int argc, char* argv[])
ck
::
index_t
sorted_tile_num
=
8
;
ck
::
index_t
sorted_tile_num
=
8
;
ck
::
index_t
sorted_tile_size
=
MPerBlock
;
ck
::
index_t
sorted_tile_size
=
MPerBlock
;
ck
::
index_t
SORTED_SIZE
=
sorted_tile_num
*
sorted_tile_size
;
ck
::
index_t
SORTED_SIZE
=
sorted_tile_num
*
sorted_tile_size
;
ck
::
index_t
tokens
=
51
2
;
ck
::
index_t
tokens
=
3
2
;
if
(
argc
==
1
)
if
(
argc
==
1
)
{
{
...
@@ -236,6 +240,8 @@ int main(int argc, char* argv[])
...
@@ -236,6 +240,8 @@ int main(int argc, char* argv[])
ck
::
index_t
StrideB
=
K
;
ck
::
index_t
StrideB
=
K
;
// ck::index_t StrideD = 0;
// ck::index_t StrideD = 0;
ck
::
index_t
StrideE
=
N
;
ck
::
index_t
StrideE
=
N
;
constexpr
ck
::
index_t
NumDTensor
=
DsDataType
::
Size
();
constexpr
auto
StrideDs
=
std
::
array
<
ck
::
index_t
,
NumDTensor
>
{
0
,
0
,
0
};
ck
::
index_t
KBatch
=
1
;
ck
::
index_t
KBatch
=
1
;
...
@@ -261,15 +267,15 @@ int main(int argc, char* argv[])
...
@@ -261,15 +267,15 @@ int main(int argc, char* argv[])
Tensor
<
A0DataType
>
a0_t_k
(
HostTensorDescriptor
({
tokens
,
K
},
{
K
,
1
}));
Tensor
<
A0DataType
>
a0_t_k
(
HostTensorDescriptor
({
tokens
,
K
},
{
K
,
1
}));
Tensor
<
B0DataType
>
b0_e_n_k
(
HostTensorDescriptor
({
experts
,
N
,
K
},
{
N
*
K
,
K
,
1
}));
Tensor
<
B0DataType
>
b0_e_n_k
(
HostTensorDescriptor
({
experts
,
N
,
K
},
{
N
*
K
,
K
,
1
}));
Tensor
<
B0DataType
>
b0_preshuffled
(
HostTensorDescriptor
({
experts
,
N
,
K
},
{
N
*
K
,
K
,
1
}));
Tensor
<
B0DataType
>
b0_preshuffled
(
HostTensorDescriptor
({
experts
,
N
,
K
},
{
N
*
K
,
K
,
1
}));
Tensor
<
D0DataType
>
d0_t_n
(
HostTensorDescriptor
({
N
,
1
},
{
1
,
0
}));
Tensor
<
D0DataType
>
d0_t_n
(
HostTensorDescriptor
({
tokens
,
N
},
{
StrideDs
[
0
]
,
0
}));
Tensor
<
D1DataType
>
d1_
m
_n
(
HostTensorDescriptor
({
SORTED_SIZE
,
N
},
{
N
,
1
}));
Tensor
<
D1DataType
>
d1_
e
_n
(
HostTensorDescriptor
({
experts
,
N
},
{
1
,
StrideDs
[
1
]
}));
Tensor
<
D2DataType
>
d2_m_n
(
HostTensorDescriptor
({
SORTED_SIZE
,
N
},
{
N
,
1
}));
Tensor
<
D2DataType
>
d2_m_n
(
HostTensorDescriptor
({
SORTED_SIZE
,
N
},
{
N
,
1
}));
Tensor
<
EDataType
>
e_m_n_host_result
(
HostTensorDescriptor
({
SORTED_SIZE
,
N
},
{
N
,
1
}));
Tensor
<
EDataType
>
e_m_n_host_result
(
HostTensorDescriptor
({
SORTED_SIZE
,
N
},
{
N
,
1
}));
Tensor
<
EDataType
>
e_m_n_device_result
(
HostTensorDescriptor
({
SORTED_SIZE
,
N
},
{
N
,
1
}));
Tensor
<
EDataType
>
e_m_n_device_result
(
HostTensorDescriptor
({
SORTED_SIZE
,
N
},
{
N
,
1
}));
std
::
cout
<<
"a0_t_k: "
<<
a0_t_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"a0_t_k: "
<<
a0_t_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b0_e_n_k: "
<<
b0_e_n_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b0_e_n_k: "
<<
b0_e_n_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d1_
m
_n: "
<<
d1_
m
_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d1_
e
_n: "
<<
d1_
e
_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d2_m_n: "
<<
d2_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d2_m_n: "
<<
d2_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d0_t_n: "
<<
d0_t_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"d0_t_n: "
<<
d0_t_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_m_n: "
<<
e_m_n_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"e_m_n: "
<<
e_m_n_host_result
.
mDesc
<<
std
::
endl
;
...
@@ -281,21 +287,21 @@ int main(int argc, char* argv[])
...
@@ -281,21 +287,21 @@ int main(int argc, char* argv[])
a0_t_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
A0DataType
>
{
-
2
,
2
});
a0_t_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
A0DataType
>
{
-
2
,
2
});
b0_e_n_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
0
,
2
});
b0_e_n_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
0
,
2
});
d0_t_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
-
2
,
2
});
d0_t_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
-
2
,
2
});
d1_
m
_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
-
2
,
2
});
d1_
e
_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
-
2
,
2
});
d2_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D2DataType
>
{
-
2
,
2
});
d2_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D2DataType
>
{
-
2
,
2
});
break
;
break
;
case
2
:
case
2
:
a0_t_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
A0DataType
>
{});
a0_t_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
A0DataType
>
{});
b0_e_n_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
b0_e_n_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
B0DataType
>
{});
d0_t_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
D0DataType
>
{});
d0_t_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
D0DataType
>
{});
d1_
m
_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
D1DataType
>
{});
d1_
e
_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
D1DataType
>
{});
d2_m_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
D2DataType
>
{});
d2_m_n
.
GenerateTensorValue
(
GeneratorTensor_1
<
D2DataType
>
{});
break
;
break
;
default:
default:
a0_t_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
0.0
,
1.0
});
a0_t_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
0.0
,
1.0
});
b0_e_n_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
-
0.5
,
0.5
});
b0_e_n_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
B0DataType
>
{
-
0.5
,
0.5
});
d0_t_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D0DataType
>
{
0.0
,
1.0
});
d0_t_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D0DataType
>
{
0.0
,
1.0
});
d1_
m
_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D1DataType
>
{
0.0
,
1.0
});
d1_
e
_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D1DataType
>
{
0.0
,
1.0
});
d2_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D2DataType
>
{
0.0
,
1.0
});
d2_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D2DataType
>
{
0.0
,
1.0
});
}
}
DeviceMem
sorted_token_ids_dev
(
sizeof
(
ck
::
index_t
)
*
sorted_token_ids
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
sorted_token_ids_dev
(
sizeof
(
ck
::
index_t
)
*
sorted_token_ids
.
mDesc
.
GetElementSpaceSize
());
...
@@ -303,7 +309,7 @@ int main(int argc, char* argv[])
...
@@ -303,7 +309,7 @@ int main(int argc, char* argv[])
DeviceMem
a0_device_buf
(
sizeof
(
A0DataType
)
*
a0_t_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a0_device_buf
(
sizeof
(
A0DataType
)
*
a0_t_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b0_device_buf
(
sizeof
(
B0DataType
)
*
b0_e_n_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b0_device_buf
(
sizeof
(
B0DataType
)
*
b0_e_n_k
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_t_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_t_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_
m
_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_
e
_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d2_device_buf
(
sizeof
(
D2DataType
)
*
d2_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d2_device_buf
(
sizeof
(
D2DataType
)
*
d2_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
a0_t_k
.
savetxt
(
"a.txt"
);
a0_t_k
.
savetxt
(
"a.txt"
);
...
@@ -311,7 +317,7 @@ int main(int argc, char* argv[])
...
@@ -311,7 +317,7 @@ int main(int argc, char* argv[])
expert_ids_dev
.
ToDevice
(
expert_ids
.
mData
.
data
());
expert_ids_dev
.
ToDevice
(
expert_ids
.
mData
.
data
());
a0_device_buf
.
ToDevice
(
a0_t_k
.
mData
.
data
());
a0_device_buf
.
ToDevice
(
a0_t_k
.
mData
.
data
());
d0_device_buf
.
ToDevice
(
d0_t_n
.
mData
.
data
());
d0_device_buf
.
ToDevice
(
d0_t_n
.
mData
.
data
());
d1_device_buf
.
ToDevice
(
d1_
m
_n
.
mData
.
data
());
d1_device_buf
.
ToDevice
(
d1_
e
_n
.
mData
.
data
());
d2_device_buf
.
ToDevice
(
d2_m_n
.
mData
.
data
());
d2_device_buf
.
ToDevice
(
d2_m_n
.
mData
.
data
());
e_device_buf
.
ToDevice
(
e_m_n_device_result
.
mData
.
data
());
e_device_buf
.
ToDevice
(
e_m_n_device_result
.
mData
.
data
());
...
@@ -319,8 +325,6 @@ int main(int argc, char* argv[])
...
@@ -319,8 +325,6 @@ int main(int argc, char* argv[])
auto
b_element_op
=
BElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{};
auto
cde_element_op
=
CDEElementOp
{};
constexpr
ck
::
index_t
NumDTensor
=
DsDataType
::
Size
();
constexpr
auto
I0
=
ck
::
Number
<
0
>
{};
constexpr
auto
I0
=
ck
::
Number
<
0
>
{};
// do GEMM
// do GEMM
...
@@ -404,7 +408,7 @@ int main(int argc, char* argv[])
...
@@ -404,7 +408,7 @@ int main(int argc, char* argv[])
const
int
t
=
sorted_token_ids
(
m
);
const
int
t
=
sorted_token_ids
(
m
);
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
cde_element_op
(
e_m_n_host_result
(
m
,
n
),
c_m_n
(
m
,
n
),
d0_t_n
(
t
,
n
),
d1_
m
_n
(
m
,
n
),
d2_m_n
(
m
,
n
));
cde_element_op
(
e_m_n_host_result
(
m
,
n
),
c_m_n
(
m
,
n
),
d0_t_n
(
t
,
n
),
d1_
e
_n
(
m
,
n
),
d2_m_n
(
m
,
n
));
}
}
}
}
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp
View file @
b02c0b82
...
@@ -43,13 +43,15 @@ template <typename ThreadGroup,
...
@@ -43,13 +43,15 @@ template <typename ThreadGroup,
typename
ThreadTransferSrcResetCoordinateAfterRunFlags
,
typename
ThreadTransferSrcResetCoordinateAfterRunFlags
,
typename
ThreadTransferDstResetCoordinateAfterRunFlags
,
typename
ThreadTransferDstResetCoordinateAfterRunFlags
,
index_t
ScatterDim
=
1
,
index_t
ScatterDim
=
1
,
bool
OutputScatter
=
true
,
index_t
ScatterWeightIdx
=
3
,
index_t
NumThreadScratch
=
1
>
index_t
NumThreadScratch
=
1
>
struct
ThreadGroupTensorSliceTransfer_v7r3_scatter
struct
ThreadGroupTensorSliceTransfer_v7r3_scatter
{
{
static
constexpr
index_t
nDim
=
static
constexpr
index_t
nDim
=
remove_cvref_t
<
tuple_element_t
<
0
,
SrcDescs
>>::
GetNumOfDimension
();
remove_cvref_t
<
tuple_element_t
<
0
,
SrcDescs
>>::
GetNumOfDimension
();
static
constexpr
index_t
mod_num
=
ThreadClusterLengths
{}.
At
(
Number
<
3
>
{});
// Dirty HACK FELIX, TODO fix
static
constexpr
index_t
mod_num
=
ThreadClusterLengths
{}.
At
(
Number
<
3
>
{})
;
// Dirty HACK FELIX, TODO fix
static
constexpr
index_t
nSrc
=
remove_cvref_t
<
SrcDescs
>::
Size
();
static
constexpr
index_t
nSrc
=
remove_cvref_t
<
SrcDescs
>::
Size
();
static
constexpr
index_t
nDst
=
remove_cvref_t
<
DstDescs
>::
Size
();
static
constexpr
index_t
nDst
=
remove_cvref_t
<
DstDescs
>::
Size
();
...
@@ -114,7 +116,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
...
@@ -114,7 +116,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
Number
<
nSrc
>
{});
Number
<
nSrc
>
{});
const
auto
dst_thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
const
auto
dst_thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()
%
mod_num
));
make_multi_index
(
OutputScatter
?
ThreadGroup
::
GetThreadId
()
%
mod_num
:
ThreadGroup
::
GetThreadId
()
));
const
auto
dst_thread_slice_origins
=
generate_tuple
(
const
auto
dst_thread_slice_origins
=
generate_tuple
(
[
&
](
auto
i
)
{
return
dst_block_slice_origins
[
i
]
+
dst_thread_cluster_idx
*
thread_slice_lengths
;
},
[
&
](
auto
i
)
{
return
dst_block_slice_origins
[
i
]
+
dst_thread_cluster_idx
*
thread_slice_lengths
;
},
Number
<
nDst
>
{});
Number
<
nDst
>
{});
...
@@ -219,6 +221,8 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
...
@@ -219,6 +221,8 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
ThreadTransferSrcResetCoordinateAfterRunFlags
,
ThreadTransferSrcResetCoordinateAfterRunFlags
,
ThreadTransferDstResetCoordinateAfterRunFlags
,
ThreadTransferDstResetCoordinateAfterRunFlags
,
ScatterDim
,
ScatterDim
,
OutputScatter
,
ScatterWeightIdx
,
NumThreadScratch
>
;
NumThreadScratch
>
;
ThreadwiseTransfer
threadwise_transfer_
;
ThreadwiseTransfer
threadwise_transfer_
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
View file @
b02c0b82
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3
_scatter
.hpp"
#define DEBUG_LOG 0
#define DEBUG_LOG 0
...
@@ -486,13 +486,36 @@ struct GridwiseMoeGemmGather
...
@@ -486,13 +486,36 @@ struct GridwiseMoeGemmGather
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
template
<
typename
DLayout
>
__host__
__device__
static
auto
MakeDGridDescriptor_M_N
(
index_t
M
,
index_t
MPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideC
)
{
const
auto
c_grid_desc_mraw_nraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
DLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I0
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
DLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I0
,
StrideC
));
}
}();
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
__host__
__device__
static
auto
MakeDsGridDescriptor_M_N
(
__host__
__device__
static
auto
MakeDsGridDescriptor_M_N
(
index_t
M
,
index_t
MPad
,
index_t
N
,
index_t
NPad
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
)
index_t
M
,
index_t
MPad
,
index_t
N
,
index_t
NPad
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
)
{
{
return
generate_tuple
(
return
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
return
Make
C
GridDescriptor_M_N
<
DLayout
>
(
M
,
MPad
,
N
,
NPad
,
StrideDs
[
i
]);
return
Make
D
GridDescriptor_M_N
<
DLayout
>
(
M
,
MPad
,
N
,
NPad
,
StrideDs
[
i
]);
},
},
Number
<
NumDTensor
>
{});
Number
<
NumDTensor
>
{});
}
}
...
@@ -509,8 +532,6 @@ struct GridwiseMoeGemmGather
...
@@ -509,8 +532,6 @@ struct GridwiseMoeGemmGather
Number
<
NumDTensor
>
{});
Number
<
NumDTensor
>
{});
}
}
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
(
0
,
0
,
0
,
0
,
{}))
>
;
struct
Problem
struct
Problem
{
{
__host__
__device__
Problem
(
index_t
NumTokens_
,
__host__
__device__
Problem
(
index_t
NumTokens_
,
...
@@ -1158,10 +1179,6 @@ struct GridwiseMoeGemmGather
...
@@ -1158,10 +1179,6 @@ struct GridwiseMoeGemmGather
// if(threadIdx.x==0)
// if(threadIdx.x==0)
// printf("tid %d eid %d expert_stride %d bufsize %d\n",
// printf("tid %d eid %d expert_stride %d bufsize %d\n",
// threadIdx.x, expert_id, expert_stride, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
// threadIdx.x, expert_id, expert_stride, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
...
@@ -1377,8 +1394,18 @@ struct GridwiseMoeGemmGather
...
@@ -1377,8 +1394,18 @@ struct GridwiseMoeGemmGather
const
auto
ds_grid_buf
=
generate_tuple
(
const
auto
ds_grid_buf
=
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
const
DDataType
*
ptr_
=
p_ds_grid
[
i
];
// hack logic here to support different kind of strides. todo fix it.
// ascale t, 1; bscale E, N, 1, move ptr to E
if
(
i
.
value
==
1
)
{
ptr_
+=
expert_id
*
(
problem
.
StrideDs
[
1
]
?
problem
.
StrideDs
[
1
]
*
problem
.
N
:
1
);
// if ( threadIdx.x ==0)
// printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]);
}
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
return
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p
_ds_grid
[
i
]
,
ds_grid_desc_m_n
[
i
].
GetElementSpaceSize
());
p
tr_
,
ds_grid_desc_m_n
[
i
].
GetElementSpaceSize
());
},
},
Number
<
NumDTensor
>
{});
Number
<
NumDTensor
>
{});
...
@@ -1411,11 +1438,23 @@ struct GridwiseMoeGemmGather
...
@@ -1411,11 +1438,23 @@ struct GridwiseMoeGemmGather
const
auto
e_grid_desc_mblock_mperblock_nblock_nperblock
=
const
auto
e_grid_desc_mblock_mperblock_nblock_nperblock
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
c_grid_desc_mblock_mperblock_nblock_nperblock
;
using
CDEBlockTransferCluster
Lengths_MBlock_MPerBlock_NBlock_NPerBlock
=
using
CDEBlockTransferCluster
=
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
;
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
;
const
auto
EGlobalMemoryDataOperation
=
CGlobalMemoryDataOperation
;
const
auto
EGlobalMemoryDataOperation
=
CGlobalMemoryDataOperation
;
constexpr
auto
EMThreads
=
CDEBlockTransferCluster
{}.
At
(
I0
)
*
CDEBlockTransferCluster
{}.
At
(
I1
);
auto
cde_block_copy_lds_and_global
=
ThreadGroupTensorSliceTransfer_v7r3
<
constexpr
auto
EMRepeats
=
MPerBlock
/
EMThreads
;
constexpr
auto
ENThreads
=
CDEBlockTransferCluster
{}.
At
(
I2
)
*
CDEBlockTransferCluster
{}.
At
(
I3
);
const
index_t
c_token_pos
=
block_m_id
*
MPerBlock
+
threadIdx
.
x
/
ENThreads
*
EMRepeats
;
StaticallyIndexedArray
<
index_t
,
EMRepeats
>
scatter_offsets
;
//= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray
<
float
,
EMRepeats
>
scatter_weights
;
//= for topk
// too hack here, 2 specific for topk weights, fixme
const
float
*
p_sorted_weights
=
p_ds_grid
[
I2
];
static_for
<
0
,
EMRepeats
,
1
>
{}([
&
](
auto
m0
)
{
scatter_offsets
(
m0
)
=
0
;
scatter_weights
(
m0
)
=
p_sorted_weights
[
c_token_pos
+
m0
];
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
});
auto
cde_block_copy_lds_and_global
=
ThreadGroupTensorSliceTransfer_v7r3_scatter
<
ThisThreadBlock
,
ThisThreadBlock
,
decltype
(
container_concat
(
make_tuple
(
CShuffleDataType
{}),
DsDataType
{})),
decltype
(
container_concat
(
make_tuple
(
CShuffleDataType
{}),
DsDataType
{})),
Tuple
<
EDataType
>
,
Tuple
<
EDataType
>
,
...
@@ -1428,7 +1467,7 @@ struct GridwiseMoeGemmGather
...
@@ -1428,7 +1467,7 @@ struct GridwiseMoeGemmGather
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CDEBlockTransferCluster
Lengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEBlockTransferCluster
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename SrcDimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename SrcDimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DstDimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DstDimAccessOrder,
...
@@ -1440,13 +1479,21 @@ struct GridwiseMoeGemmGather
...
@@ -1440,13 +1479,21 @@ struct GridwiseMoeGemmGather
Sequence
<
true
>
,
Sequence
<
true
>
,
uniform_sequence_gen_t
<
NumDTensor
,
uniform_sequence_gen_t
<
NumDTensor
,
false
>>
,
// ThreadTransferSrcResetCoordinateAfterRunFlags
false
>>
,
// ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence
<
false
>>
// ThreadTransferDstResetCoordinateAfterRunFlags
Sequence
<
false
>
,
// ThreadTransferDstResetCoordinateAfterRunFlags
1
,
//ScatterDim
false
,
//OutputScatter: false, only use scatter weights
1
// ScatterWeightIdx: ascale
>
{
c_ds_desc_refs
,
{
c_ds_desc_refs
,
idx_c_ds_block_begin
,
idx_c_ds_block_begin
,
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
tie
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
make_tuple
(
make_multi_index
(
block_m_id
,
0
,
block_n_id
,
0
)),
make_tuple
(
make_multi_index
(
block_m_id
,
0
,
block_n_id
,
0
)),
c_element_op
};
c_element_op
,
scatter_offsets
,
scatter_weights
};
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
// space filling curve for threadwise C in VGPR
// space filling curve for threadwise C in VGPR
constexpr
auto
sfc_c_vgpr
=
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
...
@@ -1472,7 +1519,6 @@ struct GridwiseMoeGemmGather
...
@@ -1472,7 +1519,6 @@ struct GridwiseMoeGemmGather
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>>
{};
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>>
{};
static_assert
(
num_access
==
sfc_cde_block
.
GetNumOfAccess
(),
"wrong!"
);
static_assert
(
num_access
==
sfc_cde_block
.
GetNumOfAccess
(),
"wrong!"
);
// printf("eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee\n");
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
// make sure it's safe to write to LDS
// make sure it's safe to write to LDS
block_sync_lds
();
block_sync_lds
();
...
...
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp
View file @
b02c0b82
...
@@ -532,7 +532,6 @@ struct GridwiseMoeGemmScatter
...
@@ -532,7 +532,6 @@ struct GridwiseMoeGemmScatter
Number
<
NumDTensor
>
{});
Number
<
NumDTensor
>
{});
}
}
struct
Problem
struct
Problem
{
{
__host__
__device__
Problem
(
index_t
NumTokens_
,
__host__
__device__
Problem
(
index_t
NumTokens_
,
...
@@ -1427,15 +1426,14 @@ struct GridwiseMoeGemmScatter
...
@@ -1427,15 +1426,14 @@ struct GridwiseMoeGemmScatter
constexpr
auto
EMThreads
=
CDEBlockTransferCluster
{}.
At
(
I0
)
*
CDEBlockTransferCluster
{}.
At
(
I1
);
constexpr
auto
EMThreads
=
CDEBlockTransferCluster
{}.
At
(
I0
)
*
CDEBlockTransferCluster
{}.
At
(
I1
);
constexpr
auto
EMRepeats
=
MPerBlock
/
EMThreads
;
constexpr
auto
EMRepeats
=
MPerBlock
/
EMThreads
;
constexpr
auto
ENThreads
=
CDEBlockTransferCluster
{}.
At
(
I2
)
*
CDEBlockTransferCluster
{}.
At
(
I3
);
constexpr
auto
ENThreads
=
CDEBlockTransferCluster
{}.
At
(
I2
)
*
CDEBlockTransferCluster
{}.
At
(
I3
);
// static_assert(EMRepeats == 1, "only support 1 line per thread now!");
const
index_t
c_token_pos
=
block_m_id
*
MPerBlock
+
threadIdx
.
x
/
ENThreads
*
EMRepeats
;
const
index_t
token_pos
=
block_m_id
*
MPerBlock
+
threadIdx
.
x
/
ENThreads
*
EMRepeats
;
StaticallyIndexedArray
<
index_t
,
EMRepeats
>
scatter_offsets
;
//= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray
<
index_t
,
EMRepeats
>
scatter_offsets
;
//= p_sorted_token_ids[token_pos];
StaticallyIndexedArray
<
float
,
EMRepeats
>
scatter_weights
;
//= for topk
StaticallyIndexedArray
<
float
,
EMRepeats
>
scatter_weights
;
//= for topk
// too hack here, 2 specific for topk weights, fixme
// too hack here, 2 specific for topk weights, fixme
const
float
*
p_sorted_weights
=
p_ds_grid
[
I2
];
const
float
*
p_sorted_weights
=
p_ds_grid
[
I2
];
static_for
<
0
,
EMRepeats
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
EMRepeats
,
1
>
{}([
&
](
auto
m0
)
{
scatter_offsets
(
m0
)
=
(
p_sorted_token_ids
[
token_pos
+
m0
]
&
0xffffff
)
*
problem
.
N
;
scatter_offsets
(
m0
)
=
(
p_sorted_token_ids
[
c_
token_pos
+
m0
]
&
0xffffff
)
*
problem
.
N
;
scatter_weights
(
m0
)
=
p_sorted_weights
[
token_pos
+
m0
];
scatter_weights
(
m0
)
=
p_sorted_weights
[
c_
token_pos
+
m0
];
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
});
});
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp
View file @
b02c0b82
...
@@ -44,6 +44,8 @@ template <typename SrcDatas,
...
@@ -44,6 +44,8 @@ template <typename SrcDatas,
typename
SrcResetCoordinateAfterRunFlags
,
// Sequence<bool ...>
typename
SrcResetCoordinateAfterRunFlags
,
// Sequence<bool ...>
typename
DstResetCoordinateAfterRunFlags
,
// Sequence<bool ...>
typename
DstResetCoordinateAfterRunFlags
,
// Sequence<bool ...>
index_t
ScatterDim
=
1
,
index_t
ScatterDim
=
1
,
bool
OutputScatter
=
true
,
index_t
ScatterWeightIdx
=
3
,
index_t
NumThreadScratch
=
1
>
index_t
NumThreadScratch
=
1
>
struct
ThreadwiseTensorSliceTransfer_v7r3_scatter
struct
ThreadwiseTensorSliceTransfer_v7r3_scatter
{
{
...
@@ -174,7 +176,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
...
@@ -174,7 +176,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
src_coords_
[
i
]);
src_coords_
[
i
]);
oob_val
=
oob_val
&
is_src_valid
;
oob_val
=
oob_val
&
is_src_valid
;
if
(
i
.
value
==
3
)
if
(
i
.
value
==
ScatterWeightIdx
)
{
{
static_assert
(
SrcScalarPerVectors
{}[
Number
<
2
>
{}]
==
1
,
"scatter weight dim, should only one vec"
);
static_assert
(
SrcScalarPerVectors
{}[
Number
<
2
>
{}]
==
1
,
"scatter weight dim, should only one vec"
);
constexpr
auto
iScatter
=
SrcSpaceFillingCurve
::
GetIndex
(
iAccess
)(
Number
<
ScatterDim
>
{});
constexpr
auto
iScatter
=
SrcSpaceFillingCurve
::
GetIndex
(
iAccess
)(
Number
<
ScatterDim
>
{});
...
@@ -187,8 +189,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
...
@@ -187,8 +189,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
using
DataType
=
remove_cvref_t
<
decltype
(
data_types
[
i
])
>
;
using
DataType
=
remove_cvref_t
<
decltype
(
data_types
[
i
])
>
;
const
auto
tmp
=
const
auto
tmp
=
src_bufs
[
i
].
template
Get
<
DataType
>(
src_coords_
[
i
].
GetOffset
(),
true
);
src_bufs
[
i
].
template
Get
<
DataType
>(
src_coords_
[
i
].
GetOffset
(),
true
);
// if(i.value == 2)
// printf("tid %d srcid %d off %d v %f\n", threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp);
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}(
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}(
[
&
](
auto
j
)
{
src_vectors
(
i
).
template
AsType
<
DataType
>()(
j
)
=
tmp
;
});
[
&
](
auto
j
)
{
src_vectors
(
i
).
template
AsType
<
DataType
>()(
j
)
=
tmp
;
});
}
}
...
@@ -420,8 +420,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
...
@@ -420,8 +420,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
// loop over space-filling curve
// loop over space-filling curve
static_for
<
0
,
dst_num_access
,
1
>
{}([
&
](
auto
iAccess
)
{
static_for
<
0
,
dst_num_access
,
1
>
{}([
&
](
auto
iAccess
)
{
auto
dst_vectors
=
dst_vectors_tuple_
[
thread_scratch_id
][
iAccess
];
auto
dst_vectors
=
dst_vectors_tuple_
[
thread_scratch_id
][
iAccess
];
constexpr
auto
iScatter
=
DstSpaceFillingCurve
::
GetIndex
(
iAccess
)(
Number
<
ScatterDim
>
{});
auto
scatter_offset
=
0
;
const
auto
scatter_offset
=
scatter_offsets_
(
Number
<
iScatter
>
{});
if
constexpr
(
OutputScatter
)
{
constexpr
auto
iScatter
=
DstSpaceFillingCurve
::
GetIndex
(
iAccess
)(
Number
<
ScatterDim
>
{});
scatter_offset
=
scatter_offsets_
(
Number
<
iScatter
>
{});
}
// copy data from buf_vectors into dst_bufs
// copy data from buf_vectors into dst_bufs
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
nDst
,
1
>
{}([
&
](
auto
i
)
{
using
dst_vector_t
=
typename
remove_cvref_t
<
decltype
(
dst_vectors
[
i
])
>::
type
;
using
dst_vector_t
=
typename
remove_cvref_t
<
decltype
(
dst_vectors
[
i
])
>::
type
;
...
@@ -459,7 +463,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
...
@@ -459,7 +463,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
Index
step_
;
Index
step_
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
step_
(
i
)
=
i
.
value
!
=
ScatterDim
?
forward_step
[
i
]
:
0
;
step_
(
i
)
=
(
i
.
value
=
=
ScatterDim
&&
OutputScatter
)
?
0
:
forward_step
[
i
];
// if(threadIdx.x==0)
// if(threadIdx.x==0)
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim);
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim);
...
@@ -530,7 +534,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
...
@@ -530,7 +534,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
{
{
Index
step_
;
Index
step_
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
step_
(
i
)
=
i
.
value
!
=
ScatterDim
?
reset_step
[
Number
<
i
>
{}]
:
0
;
step_
(
i
)
=
(
i
.
value
=
=
ScatterDim
&&
OutputScatter
)
?
0
:
reset_step
[
Number
<
i
>
{}];
// if(threadIdx.x==0)
// if(threadIdx.x==0)
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim);
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), ordered_gather_dim);
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp
View file @
b02c0b82
...
@@ -49,8 +49,9 @@ struct ReferenceMoeGemm : public device::BaseOperator
...
@@ -49,8 +49,9 @@ struct ReferenceMoeGemm : public device::BaseOperator
{
{
}
}
const
Tensor
<
ck
::
index_t
>&
expert_ids_
;
const
Tensor
<
ck
::
index_t
>&
sorted_token_ids_
;
const
Tensor
<
ck
::
index_t
>&
sorted_token_ids_
;
const
Tensor
<
ck
::
index_t
>&
expert_ids_
;
index_t
sorted_tile_size_
;
const
Tensor
<
ADataType
>&
a_t_k_
;
const
Tensor
<
ADataType
>&
a_t_k_
;
const
Tensor
<
BDataType
>&
b_e_n_k_
;
const
Tensor
<
BDataType
>&
b_e_n_k_
;
Tensor
<
CDataType
>&
c_m_n_
;
Tensor
<
CDataType
>&
c_m_n_
;
...
@@ -58,7 +59,6 @@ struct ReferenceMoeGemm : public device::BaseOperator
...
@@ -58,7 +59,6 @@ struct ReferenceMoeGemm : public device::BaseOperator
AElementwiseOperation
a_element_op_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
index_t
sorted_tile_size_
;
};
};
// Invoker
// Invoker
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment