Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7796fc73
Commit
7796fc73
authored
Feb 15, 2025
by
coderfeli
Browse files
fix gemm2 scale, gemm2 ok now
parent
61e3c238
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
28 additions
and
23 deletions
+28
-23
example/65_gemm_multiply_multiply/moe_gemm1.cpp
example/65_gemm_multiply_multiply/moe_gemm1.cpp
+2
-2
example/65_gemm_multiply_multiply/moe_gemm2.cpp
example/65_gemm_multiply_multiply/moe_gemm2.cpp
+23
-18
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp
...k/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp
+2
-2
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp
...ry/reference_tensor_operation/cpu/reference_moe_gemm2.hpp
+1
-1
No files found.
example/65_gemm_multiply_multiply/moe_gemm1.cpp
View file @
7796fc73
...
...
@@ -358,9 +358,9 @@ int main(int argc, char* argv[])
if
(
time_kernel
)
{
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
valid_
size
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
valid_
tile_num
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
A0DataType
)
*
valid_
size
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
*
experts
+
sizeof
(
EDataType
)
*
valid_
size
*
N
;
sizeof
(
A0DataType
)
*
valid_
tile_num
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
*
experts
+
sizeof
(
EDataType
)
*
valid_
tile_num
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
...
example/65_gemm_multiply_multiply/moe_gemm2.cpp
View file @
7796fc73
...
...
@@ -58,7 +58,7 @@ struct MulABScaleExpertWeight
template
<
typename
E
,
typename
C
,
typename
D0
,
typename
D1
,
typename
D2
>
__host__
__device__
constexpr
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
,
const
D1
&
d1
,
const
D2
&
d2
)
const
;
//real kernel use
//
for
real kernel use
template
<
>
__host__
__device__
constexpr
void
operator
()
<
EDataType
,
float
,
float
,
float
,
float
>
(
EDataType
&
e
,
...
...
@@ -67,11 +67,12 @@ struct MulABScaleExpertWeight
const
float
&
d1
,
const
float
&
d2
)
const
{
// e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
(
void
)
d2
;
e
=
ck
::
type_convert
<
EDataType
>
(
c
);
//for real kernel use
//warning: hack — ignore d0 for now, since the kernel multiplies d0 * d2 outside. TODO(felix): fix
(
void
)
d0
;
e
=
ck
::
type_convert
<
EDataType
>
(
c
*
d1
*
d2
);
}
// for reference
// for reference
cpu
template
<
>
__host__
__device__
constexpr
void
operator
()
<
float
,
float
,
float
,
float
,
float
>
(
float
&
e
,
...
...
@@ -80,8 +81,8 @@ struct MulABScaleExpertWeight
const
float
&
d1
,
const
float
&
d2
)
const
{
(
void
)
d2
;
e
=
ck
::
type_convert
<
EDataType
>
(
c
);
// for reference cpu
e
=
ck
::
type_convert
<
EDataType
>
(
c
*
d0
*
d1
*
d2
);
}
};
...
...
@@ -124,7 +125,7 @@ using BElementOp = PassThrough;
using
CDEElementOp
=
MulABScaleExpertWeight
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
ck
::
index_t
MPerBlock
=
64
;
static
constexpr
ck
::
index_t
MPerBlock
=
128
;
static
constexpr
ck
::
index_t
BLOCKSIZE
=
256
;
static
constexpr
ck
::
index_t
NPerBlock
=
128
;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
...
...
@@ -188,15 +189,13 @@ int main(int argc, char* argv[])
ck
::
index_t
N
=
6144
;
ck
::
index_t
K
=
8192
;
ck
::
index_t
experts
=
8
;
ck
::
index_t
sorted_tile_num
=
9
;
ck
::
index_t
sorted_tile_num
=
10
;
ck
::
index_t
valid_tile_num
=
8
;
ck
::
index_t
sorted_size
=
sorted_tile_num
*
MPerBlock
;
ck
::
index_t
valid_size
=
valid_tile_num
*
MPerBlock
;
ck
::
index_t
batch
=
64
;
ck
::
index_t
tokens
=
64
;
ck
::
index_t
topk
=
2
;
ck
::
index_t
tokens
=
batch
;
if
(
argc
==
1
)
{
// use default case
...
...
@@ -236,6 +235,11 @@ int main(int argc, char* argv[])
for
(
int
i
=
0
;
i
<
sorted_tile_num
;
i
++
)
{
expert_ids
.
mData
[
i
]
=
i
;
}
if
(
tokens
*
topk
>
valid_size
)
{
printf
(
"err config, tokens * topk > valid_size
\n
"
);
exit
(
-
1
);
}
int
token_per_tile
=
tokens
*
topk
/
valid_tile_num
;
int
tokenid
=
0
;
// sorted_token_ids.mData[0] = 0;
...
...
@@ -243,20 +247,21 @@ int main(int argc, char* argv[])
int
tile_off
=
i
%
valid_size
;
if
(
tile_off
<
token_per_tile
)
{
sorted_token_ids
.
mData
[
i
]
=
(
tokenid
%
batch
)
|
((
tokenid
/
batch
)
<<
24
);
sorted_token_ids
.
mData
[
i
]
=
(
tokenid
%
tokens
)
|
((
tokenid
/
tokens
)
<<
24
);
tokenid
++
;
}
else
{
sorted_token_ids
.
mData
[
i
]
=
tokens
;
}
}
expert_ids
.
savetxt
(
"expert_ids.txt"
,
"int"
);
sorted_token_ids
.
savetxt
(
"sorted_token_ids.txt"
,
"int"
);
Tensor
<
A0DataType
>
a0_t_k_k
(
HostTensorDescriptor
({
batch
,
topk
,
K
},
{
topk
*
K
,
K
,
1
}));
Tensor
<
A0DataType
>
a0_t_k_k
(
HostTensorDescriptor
({
tokens
,
topk
,
K
},
{
topk
*
K
,
K
,
1
}));
Tensor
<
B0DataType
>
b0_e_n_k
(
HostTensorDescriptor
({
experts
,
N
,
K
},
{
N
*
K
,
K
,
1
}));
Tensor
<
B0DataType
>
b0_preshuffled
(
HostTensorDescriptor
({
experts
,
N
,
K
},
{
N
*
K
,
K
,
1
}));
Tensor
<
D0DataType
>
d0_t_n
(
HostTensorDescriptor
({
batch
,
N
},
{
StrideDs
[
0
],
0
}));
Tensor
<
D0DataType
>
d0_t_n
(
HostTensorDescriptor
({
tokens
,
N
},
{
StrideDs
[
0
],
0
}));
Tensor
<
D1DataType
>
d1_e_n
(
HostTensorDescriptor
({
experts
,
N
},
{
1
,
StrideDs
[
1
]}));
Tensor
<
D2DataType
>
d2_e_n
(
HostTensorDescriptor
({
sorted_size
,
N
},
{
1
,
0
}));
Tensor
<
EDataType
>
e_t_n_host_result
(
HostTensorDescriptor
({
tokens
,
N
},
{
N
,
1
}));
...
...
@@ -274,7 +279,7 @@ int main(int argc, char* argv[])
case
0
:
break
;
case
1
:
a0_t_k_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
A0DataType
>
{
-
2
,
2
});
b0_e_n_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
0
,
2
});
b0_e_n_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
d0_t_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
-
2
,
2
});
d1_e_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
-
2
,
2
});
d2_e_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D2DataType
>
{
-
2
,
2
});
...
...
@@ -366,9 +371,9 @@ int main(int argc, char* argv[])
// result is not correct here because the output buffer is not zero-initialized
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
valid_size
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
tokens
*
topk
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
A0DataType
)
*
valid_size
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
*
experts
+
sizeof
(
EDataType
)
*
valid_size
*
N
;
sizeof
(
A0DataType
)
*
tokens
*
K
*
topk
+
sizeof
(
B0DataType
)
*
K
*
N
*
experts
+
sizeof
(
EDataType
)
*
tokens
*
N
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_scatter.hpp
View file @
7796fc73
...
...
@@ -1139,7 +1139,7 @@ struct GridwiseMoeGemmScatter
{
ignore
=
b_element_op
;
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
NumTokens
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
problem
.
NumTokens
*
problem
.
TopK
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bpreshuffled
=
MakeBGridDescriptor_Preshuffled
(
problem
.
BN0Shuffled
,
problem
.
BK0Shuffled
);
...
...
@@ -1459,7 +1459,7 @@ struct GridwiseMoeGemmScatter
static_for
<
0
,
EMRepeats
,
1
>
{}([
&
](
auto
m0
)
{
scatter_offsets
(
m0
)
=
(
p_sorted_token_ids
[
c_token_pos
+
m0
]
&
0xffffff
)
*
problem
.
N
;
scatter_weights
(
m0
)
=
p_sorted_weights_2
[
c_token_pos
+
m0
]
*
p_sorted_weights_0
[(
c_token_pos
+
m0
)
*
problem
.
StrideDs
[
0
]];
*
p_sorted_weights_0
[(
c_token_pos
+
m0
)
*
problem
.
StrideDs
[
0
]];
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
});
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp
View file @
7796fc73
...
...
@@ -125,7 +125,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
CDataType
v_c
{
0
};
D0DataType
v_d0
=
arg
.
d0_
(
m
,
n
);
// a
D0DataType
v_d1
=
arg
.
d1_
(
e
,
n
);
// b
arg
.
c_element_op_
(
v_c
,
v_acc
,
v_d0
*
v_topk_w
,
v_d1
,
v_topk_w
);
arg
.
c_element_op_
(
v_c
,
v_acc
,
v_d0
,
v_d1
,
v_topk_w
);
arg
.
c_t_n_
(
t
,
n
)
+=
v_c
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment