Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b3ae04f8
Commit
b3ae04f8
authored
Feb 14, 2025
by
coderfeli
Browse files
fix ref gemm no padding
parent
1078d229
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
10 deletions
+9
-10
example/65_gemm_multiply_multiply/moe_gemm1.cpp
example/65_gemm_multiply_multiply/moe_gemm1.cpp
+8
-8
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp
...ary/reference_tensor_operation/cpu/reference_moe_gemm.hpp
+1
-2
No files found.
example/65_gemm_multiply_multiply/moe_gemm1.cpp
View file @
b3ae04f8
...
...
@@ -132,7 +132,7 @@ using AElementOp = PassThrough;
using
BElementOp
=
PassThrough
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
ck
::
index_t
MPerBlock
=
128
;
static
constexpr
ck
::
index_t
MPerBlock
=
32
;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
static
constexpr
ck
::
index_t
CShuffleMXDLPerWave
=
MPerBlock
/
32
;
static
constexpr
ck
::
index_t
KPerBlock
=
256
/
sizeof
(
A0DataType
);
...
...
@@ -255,13 +255,13 @@ int main(int argc, char* argv[])
}
expert_ids
.
savetxt
(
"expert_ids.txt"
,
"int"
);
sorted_token_ids
.
savetxt
(
"sorted_token_ids.txt"
,
"int"
);
Tensor
<
A0DataType
>
a0_t_k
(
HostTensorDescriptor
({
tokens
,
K
},
{
K
,
1
}));
Tensor
<
A0DataType
>
a0_t_k
(
HostTensorDescriptor
({
batch
,
K
},
{
K
,
1
}));
Tensor
<
B0DataType
>
b0_e_n_k
(
HostTensorDescriptor
({
experts
,
N
,
K
},
{
N
*
K
,
K
,
1
}));
Tensor
<
B0DataType
>
b0_preshuffled
(
HostTensorDescriptor
({
experts
,
N
,
K
},
{
N
*
K
,
K
,
1
}));
Tensor
<
D0DataType
>
d0_t_n
(
HostTensorDescriptor
({
tokens
,
N
},
{
StrideDs
[
0
],
0
}));
Tensor
<
D0DataType
>
d0_t_n
(
HostTensorDescriptor
({
batch
,
N
},
{
StrideDs
[
0
],
0
}));
Tensor
<
D1DataType
>
d1_e_n
(
HostTensorDescriptor
({
experts
,
N
},
{
1
,
StrideDs
[
1
]}));
Tensor
<
EDataType
>
e_t_n_host_result
(
HostTensorDescriptor
({
tokens
,
topk
,
N
},
{
topk
*
N
,
N
,
1
}));
Tensor
<
EDataType
>
e_t_n_device_result
(
HostTensorDescriptor
({
tokens
,
topk
,
N
},
{
topk
*
N
,
N
,
1
}));
Tensor
<
EDataType
>
e_t_n_host_result
(
HostTensorDescriptor
({
batch
,
topk
,
N
},
{
topk
*
N
,
N
,
1
}));
Tensor
<
EDataType
>
e_t_n_device_result
(
HostTensorDescriptor
({
batch
,
topk
,
N
},
{
topk
*
N
,
N
,
1
}));
std
::
cout
<<
"a0_t_k: "
<<
a0_t_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b0_e_n_k: "
<<
b0_e_n_k
.
mDesc
<<
std
::
endl
;
...
...
@@ -370,7 +370,7 @@ int main(int argc, char* argv[])
e_device_buf
.
FromDevice
(
e_t_n_device_result
.
mData
.
data
());
Tensor
<
CShuffleDataType
>
c_t_k_n
({
tokens
,
topk
,
N
},
{
topk
*
N
,
N
,
1
});
Tensor
<
CShuffleDataType
>
c_t_k_n
({
batch
,
topk
,
N
},
{
topk
*
N
,
N
,
1
});
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceMoeGemm
<
A0DataType
,
B0DataType
,
...
...
@@ -401,8 +401,8 @@ int main(int argc, char* argv[])
const
int
e
=
expert_ids
(
m
/
sorted_tile_size
);
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
cde_element_op
(
e_t_n_host_result
(
t
,
topk_id
,
n
),
c_t_k_n
(
m
,
topk_id
,
n
),
d0_t_n
(
t
,
n
),
d1_e_n
(
e
,
n
));
printf
(
"m %d
fuset %d %d
%d %f %f
\n
"
,
m
,
topk_id
,
t
,
n
,
e_t_n_host_result
(
t
,
topk_id
,
n
),
c_t_k_n
(
m
,
topk_id
,
n
));
cde_element_op
(
e_t_n_host_result
(
t
,
topk_id
,
n
),
c_t_k_n
(
t
,
topk_id
,
n
),
d0_t_n
(
t
,
n
),
d1_e_n
(
e
,
n
));
printf
(
"m %d
n %d topk %d token
%d %f %f
\n
"
,
m
,
n
,
topk_id
,
t
,
e_t_n_host_result
(
t
,
topk_id
,
n
),
c_t_k_n
(
t
,
topk_id
,
n
));
}
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp
View file @
b3ae04f8
...
...
@@ -74,8 +74,6 @@ struct ReferenceMoeGemm : public device::BaseOperator
AccDataType
v_acc
{
0
};
ComputeTypeA
v_a
{
0
};
ComputeTypeB
v_b
{
0
};
if
(
m
>=
max_sorted_num
)
return
;
const
int
t
=
arg
.
sorted_token_ids_
(
m
)
&
0xffffff
;
const
int
topk_id
=
(
arg
.
sorted_token_ids_
(
m
)
&
0xff000000
)
>>
24
;
const
int
e
=
arg
.
expert_ids_
(
m
/
arg
.
sorted_tile_size_
);
...
...
@@ -112,6 +110,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_t_k_n_
(
t
,
topk_id
,
n
)
=
v_c
;
printf
(
"ref m %d n %d t %d topk %d v %f
\n
"
,
m
,
n
,
t
,
topk_id
,
v_c
);
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment