Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
08546656
Commit
08546656
authored
Apr 07, 2023
by
Alan Turner
Browse files
Formatting
parent
cc30b7c1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
17 deletions
+46
-17
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+46
-17
No files found.
src/targets/gpu/jit/ck_gemm.cpp
View file @
08546656
...
@@ -299,29 +299,57 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -299,29 +299,57 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto
transb
=
transposed_matrix
(
b_shape
);
auto
transb
=
transposed_matrix
(
b_shape
);
std
::
string
instance_str1
;
std
::
string
instance_str1
;
std
::
string
instance_str2
;
std
::
string
instance_str2
;
if
(
transa
and
not
transb
)
if
(
transa
and
not
transb
)
{
{
instance_str1
=
"DeviceGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, "
;
instance_str1
=
instance_str2
=
", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>"
;
"DeviceGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, "
"int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, "
;
instance_str2
=
", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, "
"S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, "
"2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, "
"4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, "
" S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, "
" 4>"
;
}
}
else
if
(
transa
and
transb
)
else
if
(
transa
and
transb
)
{
{
instance_str1
=
"DeviceGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, "
;
instance_str1
=
instance_str2
=
", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>"
;
"DeviceGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, "
"int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, "
;
instance_str2
=
", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, "
"S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, "
"2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, "
"4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, "
" S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, "
" 4>"
;
}
}
else
if
(
not
transa
and
not
transb
)
else
if
(
not
transa
and
not
transb
)
{
{
instance_str1
=
"DeviceGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, "
;
instance_str1
=
instance_str2
=
", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>"
;
"DeviceGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, "
"int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, "
;
instance_str2
=
", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, "
"S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, "
"3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, "
"4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, "
" S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, "
" 4>"
;
}
}
else
else
{
{
instance_str1
=
"DeviceGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, "
;
instance_str1
=
instance_str2
=
", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>"
;
"DeviceGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, "
"int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, "
;
instance_str2
=
", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, "
"S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, "
"3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, "
"4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, "
" S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, "
" 4>"
;
}
}
auto
rank
=
a_shape
.
lens
().
size
();
auto
rank
=
a_shape
.
lens
().
size
();
...
@@ -361,14 +389,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -361,14 +389,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
// gemm_type += "N";
// gemm_type += "N";
// if (int_div_ceil(k, k_per_block) * k_per_block - k != 0)
// if (int_div_ceil(k, k_per_block) * k_per_block - k != 0)
// gemm_type += "K";
// gemm_type += "K";
if
((
int_div_ceil
(
m
,
m_per_block
)
*
m_per_block
-
m
!=
0
)
or
(
int_div_ceil
(
n
,
n_per_block
)
*
n_per_block
-
n
!=
0
))
if
((
int_div_ceil
(
m
,
m_per_block
)
*
m_per_block
-
m
!=
0
)
or
(
int_div_ceil
(
n
,
n_per_block
)
*
n_per_block
-
n
!=
0
))
gemm_type
=
"MNPadding"
;
gemm_type
=
"MNPadding"
;
else
else
gemm_type
=
"Default"
;
gemm_type
=
"Default"
;
ip
.
set_gemm
(
"ck::tensor_operation::device::GemmSpecialization::"
+
gemm_type
);
ip
.
set_gemm
(
"ck::tensor_operation::device::GemmSpecialization::"
+
gemm_type
);
std
::
string
padding_str
=
"ck::tensor_operation::device::GemmSpecialization::"
+
gemm_type
;
std
::
string
padding_str
=
"ck::tensor_operation::device::GemmSpecialization::"
+
gemm_type
;
std
::
cout
<<
padding_str
<<
std
::
endl
;
std
::
cout
<<
padding_str
<<
std
::
endl
;
//std::exit(0);
//
std::exit(0);
auto
blocks_per_batch
=
int_div_ceil
(
m
,
128
)
*
int_div_ceil
(
n
,
128
);
auto
blocks_per_batch
=
int_div_ceil
(
m
,
128
)
*
int_div_ceil
(
n
,
128
);
;
// ip.get_grid_size(config);
;
// ip.get_grid_size(config);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment