Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1078d229
Commit
1078d229
authored
Feb 14, 2025
by
coderfeli
Browse files
add logics and debug
parent
d4b8f1e3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
9 deletions
+14
-9
example/65_gemm_multiply_multiply/moe_gemm1.cpp
example/65_gemm_multiply_multiply/moe_gemm1.cpp
+6
-3
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
...ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
+1
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp
...ary/reference_tensor_operation/cpu/reference_moe_gemm.hpp
+7
-6
No files found.
example/65_gemm_multiply_multiply/moe_gemm1.cpp
View file @
1078d229
...
@@ -245,7 +245,7 @@ int main(int argc, char* argv[])
...
@@ -245,7 +245,7 @@ int main(int argc, char* argv[])
int
tile_off
=
i
%
sorted_tile_size
;
int
tile_off
=
i
%
sorted_tile_size
;
if
(
tile_off
<
token_per_tile
)
if
(
tile_off
<
token_per_tile
)
{
{
sorted_token_ids
.
mData
[
i
]
=
(
tokenid
%
batch
)
&
((
tokenid
/
batch
)
<<
24
);
sorted_token_ids
.
mData
[
i
]
=
(
tokenid
%
batch
)
|
((
tokenid
/
batch
)
<<
24
);
tokenid
++
;
tokenid
++
;
}
}
else
else
...
@@ -389,17 +389,20 @@ int main(int argc, char* argv[])
...
@@ -389,17 +389,20 @@ int main(int argc, char* argv[])
for
(
int
m
=
0
;
m
<
SORTED_SIZE
;
++
m
)
for
(
int
m
=
0
;
m
<
SORTED_SIZE
;
++
m
)
{
{
const
int
fuse_t
=
sorted_token_ids
(
m
)
;
const
int
fuse_t
=
sorted_token_ids
.
mData
[
m
]
;
const
int
t
=
fuse_t
&
0xffffff
;
const
int
t
=
fuse_t
&
0xffffff
;
const
int
topk_id
=
(
fuse_t
&
0xff000000
)
>>
24
;
printf
(
"m %d fuset %d %d %d
\n
"
,
m
,
fuse_t
,
t
,
topk_id
);
if
(
t
>=
tokens
)
if
(
t
>=
tokens
)
{
{
continue
;
continue
;
}
}
const
int
topk_id
=
(
fuse_t
&
0xff000000
)
>>
24
;
const
int
e
=
expert_ids
(
m
/
sorted_tile_size
);
const
int
e
=
expert_ids
(
m
/
sorted_tile_size
);
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
cde_element_op
(
e_t_n_host_result
(
t
,
topk_id
,
n
),
c_t_k_n
(
m
,
topk_id
,
n
),
d0_t_n
(
t
,
n
),
d1_e_n
(
e
,
n
));
cde_element_op
(
e_t_n_host_result
(
t
,
topk_id
,
n
),
c_t_k_n
(
m
,
topk_id
,
n
),
d0_t_n
(
t
,
n
),
d1_e_n
(
e
,
n
));
printf
(
"m %d fuset %d %d %d %f %f
\n
"
,
m
,
topk_id
,
t
,
n
,
e_t_n_host_result
(
t
,
topk_id
,
n
),
c_t_k_n
(
m
,
topk_id
,
n
));
}
}
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp
View file @
1078d229
...
@@ -535,6 +535,7 @@ struct GridwiseMoeGemmGather
...
@@ -535,6 +535,7 @@ struct GridwiseMoeGemmGather
struct
Problem
struct
Problem
{
{
__host__
__device__
Problem
(
index_t
NumTokens_
,
__host__
__device__
Problem
(
index_t
NumTokens_
,
index_t
TopK_
,
index_t
M_
,
index_t
M_
,
index_t
N_
,
index_t
N_
,
index_t
K_
,
index_t
K_
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp
View file @
1078d229
...
@@ -74,6 +74,8 @@ struct ReferenceMoeGemm : public device::BaseOperator
...
@@ -74,6 +74,8 @@ struct ReferenceMoeGemm : public device::BaseOperator
AccDataType
v_acc
{
0
};
AccDataType
v_acc
{
0
};
ComputeTypeA
v_a
{
0
};
ComputeTypeA
v_a
{
0
};
ComputeTypeB
v_b
{
0
};
ComputeTypeB
v_b
{
0
};
if
(
m
>=
max_sorted_num
)
return
;
const
int
t
=
arg
.
sorted_token_ids_
(
m
)
&
0xffffff
;
const
int
t
=
arg
.
sorted_token_ids_
(
m
)
&
0xffffff
;
const
int
topk_id
=
(
arg
.
sorted_token_ids_
(
m
)
&
0xff000000
)
>>
24
;
const
int
topk_id
=
(
arg
.
sorted_token_ids_
(
m
)
&
0xff000000
)
>>
24
;
const
int
e
=
arg
.
expert_ids_
(
m
/
arg
.
sorted_tile_size_
);
const
int
e
=
arg
.
expert_ids_
(
m
/
arg
.
sorted_tile_size_
);
...
@@ -105,17 +107,16 @@ struct ReferenceMoeGemm : public device::BaseOperator
...
@@ -105,17 +107,16 @@ struct ReferenceMoeGemm : public device::BaseOperator
v_acc
+=
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
}
}
}
CDataType
v_c
{
0
};
CDataType
v_c
{
0
};
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_t_k_n_
(
t
,
topk_id
,
n
)
=
v_c
;
arg
.
c_t_k_n_
(
t
,
topk_id
,
n
)
=
v_c
;
}
};
};
make_ParallelTensorFunctor
(
make_ParallelTensorFunctor
(
f_mk_kn_mn
,
arg
.
sorted_t
ile_size_
,
arg
.
c_t_k_n_
.
mDesc
.
GetLengths
()[
2
])(
f_mk_kn_mn
,
arg
.
sorted_t
oken_ids_
.
GetLengths
()[
0
]
,
arg
.
c_t_k_n_
.
mDesc
.
GetLengths
()[
2
])(
std
::
thread
::
hardware_concurrency
());
std
::
thread
::
hardware_concurrency
());
return
0
;
return
0
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment