Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
56cc306d
Commit
56cc306d
authored
Feb 15, 2025
by
coderfeli
Browse files
fix perf calc
parent
7572a691
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
13 deletions
+20
-13
example/65_gemm_multiply_multiply/moe_gemm1.cpp
example/65_gemm_multiply_multiply/moe_gemm1.cpp
+16
-10
example/65_gemm_multiply_multiply/moe_gemm2.cpp
example/65_gemm_multiply_multiply/moe_gemm2.cpp
+4
-3
No files found.
example/65_gemm_multiply_multiply/moe_gemm1.cpp
View file @
56cc306d
...
...
@@ -132,7 +132,7 @@ using AElementOp = PassThrough;
using
BElementOp
=
PassThrough
;
static
constexpr
auto
GemmSpec
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
ck
::
index_t
MPerBlock
=
32
;
static
constexpr
ck
::
index_t
MPerBlock
=
128
;
static
constexpr
ck
::
index_t
BLOCKSIZE
=
256
;
static
constexpr
ck
::
index_t
NPerBlock
=
128
;
static
constexpr
ck
::
index_t
MNPerXDL
=
32
;
...
...
@@ -194,7 +194,7 @@ int main(int argc, char* argv[])
ck
::
index_t
N
=
6144
;
ck
::
index_t
K
=
8192
;
ck
::
index_t
experts
=
8
;
ck
::
index_t
sorted_tile_num
=
9
;
ck
::
index_t
sorted_tile_num
=
8
;
ck
::
index_t
valid_tile_num
=
8
;
ck
::
index_t
sorted_size
=
sorted_tile_num
*
MPerBlock
;
ck
::
index_t
valid_size
=
valid_tile_num
*
MPerBlock
;
...
...
@@ -207,13 +207,14 @@ int main(int argc, char* argv[])
{
// use default case
}
else
if
(
argc
==
6
)
else
if
(
argc
==
7
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
N
=
std
::
stoi
(
argv
[
4
]);
K
=
std
::
stoi
(
argv
[
5
]);
tokens
=
std
::
stoi
(
argv
[
6
]);
}
else
{
...
...
@@ -221,10 +222,15 @@ int main(int argc, char* argv[])
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 5: N, K
\n
"
);
"arg4 to 5: N, K
, tokens
\n
"
);
exit
(
0
);
}
if
(
tokens
*
topk
>
valid_size
)
{
printf
(
"err config, tokens * topk > valid_size
\n
"
);
exit
(
-
1
);
}
ck
::
index_t
StrideA
=
K
;
ck
::
index_t
StrideB
=
K
;
ck
::
index_t
StrideE
=
N
;
...
...
@@ -235,7 +241,7 @@ int main(int argc, char* argv[])
// const ck::index_t experts = 8;
Tensor
<
ck
::
index_t
>
expert_ids
(
HostTensorDescriptor
({
experts
},
{
1
}));
Tensor
<
ck
::
index_t
>
expert_ids
(
HostTensorDescriptor
({
sorted_tile_num
},
{
1
}));
Tensor
<
ck
::
index_t
>
sorted_token_ids
(
HostTensorDescriptor
({
sorted_size
},
{
1
}));
Tensor
<
ck
::
index_t
>
max_token_id
(
HostTensorDescriptor
({
1
}));
max_token_id
.
mData
[
0
]
=
valid_size
;
...
...
@@ -246,7 +252,7 @@ int main(int argc, char* argv[])
int
tokenid
=
0
;
// sorted_token_ids.mData[0] = 0;
for
(
int
i
=
0
;
i
<
sorted_size
;
i
++
)
{
int
tile_off
=
i
%
valid_size
;
int
tile_off
=
i
%
MPerBlock
;
if
(
tile_off
<
token_per_tile
)
{
sorted_token_ids
.
mData
[
i
]
=
(
tokenid
%
tokens
)
|
((
tokenid
/
tokens
)
<<
24
);
...
...
@@ -278,9 +284,9 @@ int main(int argc, char* argv[])
case
0
:
break
;
case
1
:
a0_t_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
A0DataType
>
{
-
2
,
2
});
b0_e_n_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
0
,
2
});
d0_t_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
1
,
3
});
d1_e_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
1
,
3
});
b0_e_n_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
d0_t_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
-
2
,
2
});
d1_e_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
-
2
,
2
});
break
;
case
2
:
a0_t_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
A0DataType
>
{});
...
...
@@ -358,7 +364,7 @@ int main(int argc, char* argv[])
if
(
time_kernel
)
{
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
valid_tile_num
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
tokens
*
topk
*
N
*
K
;
std
::
size_t
num_btype
=
sizeof
(
A0DataType
)
*
valid_tile_num
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
*
experts
+
sizeof
(
EDataType
)
*
valid_tile_num
*
N
;
...
...
example/65_gemm_multiply_multiply/moe_gemm2.cpp
View file @
56cc306d
...
...
@@ -200,13 +200,14 @@ int main(int argc, char* argv[])
{
// use default case
}
else
if
(
argc
==
6
)
else
if
(
argc
==
7
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
N
=
std
::
stoi
(
argv
[
4
]);
K
=
std
::
stoi
(
argv
[
5
]);
tokens
=
std
::
stoi
(
argv
[
6
]);
}
else
{
...
...
@@ -214,7 +215,7 @@ int main(int argc, char* argv[])
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to
5
: N, K
\n
"
);
"arg4 to
6
: N, K
, tokens
\n
"
);
exit
(
0
);
}
...
...
@@ -244,7 +245,7 @@ int main(int argc, char* argv[])
int
tokenid
=
0
;
// sorted_token_ids.mData[0] = 0;
for
(
int
i
=
0
;
i
<
sorted_size
;
i
++
)
{
int
tile_off
=
i
%
valid_size
;
int
tile_off
=
i
%
MPerBlock
;
if
(
tile_off
<
token_per_tile
)
{
sorted_token_ids
.
mData
[
i
]
=
(
tokenid
%
tokens
)
|
((
tokenid
/
tokens
)
<<
24
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment