Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5311d1b3
Unverified
Commit
5311d1b3
authored
Oct 03, 2023
by
zjing14
Committed by
GitHub
Oct 03, 2023
Browse files
changed test for grouped_gemm to be random (#959)
Co-authored-by:
Jing Zhang
<
jizha@amd.com
>
parent
aa46039f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
25 additions
and
36 deletions
+25
-36
client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp
.../21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp
+4
-5
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp
...nt_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp
+4
-6
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp
+4
-5
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp
+4
-5
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
+3
-5
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp
+3
-5
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
+3
-5
No files found.
client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp
View file @
5311d1b3
...
@@ -60,14 +60,13 @@ int main()
...
@@ -60,14 +60,13 @@ int main()
int
sum_of_m
=
0
;
int
sum_of_m
=
0
;
Ms
=
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
const
int
group_count
=
16
;
int
group_count
=
Ms
.
size
();
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
{
Ns
.
push_back
(
768
);
Ms
.
push_back
(
256
+
256
*
i
);
Ks
.
push_back
(
4608
);
Ns
.
push_back
(
128
+
128
*
i
);
Ks
.
push_back
(
128
+
64
*
i
);
StrideAs
.
push_back
(
std
::
is_same
<
Row
,
ALayout
>::
value
?
Ks
[
i
]
:
Ms
[
i
]);
StrideAs
.
push_back
(
std
::
is_same
<
Row
,
ALayout
>::
value
?
Ks
[
i
]
:
Ms
[
i
]);
StrideBs
.
push_back
(
std
::
is_same
<
Row
,
BLayout
>::
value
?
Ns
[
i
]
:
Ks
[
i
]);
StrideBs
.
push_back
(
std
::
is_same
<
Row
,
BLayout
>::
value
?
Ns
[
i
]
:
Ks
[
i
]);
...
...
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp
View file @
5311d1b3
...
@@ -57,15 +57,13 @@ int main()
...
@@ -57,15 +57,13 @@ int main()
int
sum_of_m
=
0
;
int
sum_of_m
=
0
;
// Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
const
int
group_count
=
16
;
Ms
=
{
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
};
int
group_count
=
Ms
.
size
();
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
{
Ns
.
push_back
(
768
);
Ms
.
push_back
(
256
+
256
*
i
);
Ks
.
push_back
(
4608
);
Ns
.
push_back
(
128
+
128
*
i
);
Ks
.
push_back
(
128
+
64
*
i
);
StrideAs
.
push_back
(
std
::
is_same
<
Row
,
ALayout
>::
value
?
Ks
[
i
]
:
Ms
[
i
]);
StrideAs
.
push_back
(
std
::
is_same
<
Row
,
ALayout
>::
value
?
Ks
[
i
]
:
Ms
[
i
]);
StrideBs
.
push_back
(
std
::
is_same
<
Row
,
BLayout
>::
value
?
Ns
[
i
]
:
Ks
[
i
]);
StrideBs
.
push_back
(
std
::
is_same
<
Row
,
BLayout
>::
value
?
Ns
[
i
]
:
Ks
[
i
]);
...
...
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp
View file @
5311d1b3
...
@@ -58,14 +58,13 @@ int main()
...
@@ -58,14 +58,13 @@ int main()
int
sum_of_m
=
0
;
int
sum_of_m
=
0
;
Ms
=
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
const
int
group_count
=
16
;
int
group_count
=
Ms
.
size
();
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
{
Ns
.
push_back
(
768
);
Ms
.
push_back
(
256
+
256
*
i
);
Ks
.
push_back
(
4608
);
Ns
.
push_back
(
128
+
128
*
i
);
Ks
.
push_back
(
128
+
64
*
i
);
StrideAs
.
push_back
(
std
::
is_same
<
Row
,
ALayout
>::
value
?
Ks
[
i
]
:
Ms
[
i
]);
StrideAs
.
push_back
(
std
::
is_same
<
Row
,
ALayout
>::
value
?
Ks
[
i
]
:
Ms
[
i
]);
StrideBs
.
push_back
(
std
::
is_same
<
Row
,
BLayout
>::
value
?
Ns
[
i
]
:
Ks
[
i
]);
StrideBs
.
push_back
(
std
::
is_same
<
Row
,
BLayout
>::
value
?
Ns
[
i
]
:
Ks
[
i
]);
...
...
client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp
View file @
5311d1b3
...
@@ -58,14 +58,13 @@ int main()
...
@@ -58,14 +58,13 @@ int main()
int
sum_of_m
=
0
;
int
sum_of_m
=
0
;
Ms
=
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
const
int
group_count
=
16
;
int
group_count
=
Ms
.
size
();
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
for
(
int
i
=
0
;
i
<
group_count
;
++
i
)
{
{
Ns
.
push_back
(
768
);
Ms
.
push_back
(
256
+
256
*
i
);
Ks
.
push_back
(
4608
);
Ns
.
push_back
(
128
+
128
*
i
);
Ks
.
push_back
(
128
+
64
*
i
);
StrideAs
.
push_back
(
std
::
is_same
<
Row
,
ALayout
>::
value
?
Ks
[
i
]
:
Ms
[
i
]);
StrideAs
.
push_back
(
std
::
is_same
<
Row
,
ALayout
>::
value
?
Ks
[
i
]
:
Ms
[
i
]);
StrideBs
.
push_back
(
std
::
is_same
<
Row
,
BLayout
>::
value
?
Ns
[
i
]
:
Ks
[
i
]);
StrideBs
.
push_back
(
std
::
is_same
<
Row
,
BLayout
>::
value
?
Ns
[
i
]
:
Ks
[
i
]);
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
View file @
5311d1b3
...
@@ -296,13 +296,11 @@ int main(int argc, char* argv[])
...
@@ -296,13 +296,11 @@ int main(int argc, char* argv[])
problem_size
.
group_count
=
16
;
problem_size
.
group_count
=
16
;
problem_size
.
Ms
=
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
{
{
problem_size
.
Ns
.
push_back
(
768
);
problem_size
.
Ms
.
push_back
(
256
+
256
*
i
);
problem_size
.
Ks
.
push_back
(
4608
);
problem_size
.
Ns
.
push_back
(
128
+
128
*
i
);
problem_size
.
Ks
.
push_back
(
128
+
64
*
i
);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp
View file @
5311d1b3
...
@@ -297,13 +297,11 @@ int main(int argc, char* argv[])
...
@@ -297,13 +297,11 @@ int main(int argc, char* argv[])
problem_size
.
group_count
=
16
;
problem_size
.
group_count
=
16
;
problem_size
.
Ms
=
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
{
{
problem_size
.
Ns
.
push_back
(
768
);
problem_size
.
Ms
.
push_back
(
256
+
256
*
i
);
problem_size
.
Ks
.
push_back
(
4608
);
problem_size
.
Ns
.
push_back
(
128
+
128
*
i
);
problem_size
.
Ks
.
push_back
(
128
+
64
*
i
);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
...
...
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
View file @
5311d1b3
...
@@ -66,13 +66,11 @@ int main(int argc, char* argv[])
...
@@ -66,13 +66,11 @@ int main(int argc, char* argv[])
problem_size
.
group_count
=
16
;
problem_size
.
group_count
=
16
;
problem_size
.
Ms
=
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
{
{
problem_size
.
Ns
.
push_back
(
768
);
problem_size
.
Ms
.
push_back
(
256
+
256
*
i
);
problem_size
.
Ks
.
push_back
(
4608
);
problem_size
.
Ns
.
push_back
(
128
+
128
*
i
);
problem_size
.
Ks
.
push_back
(
128
+
64
*
i
);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment