Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ffe2c0cc
Commit
ffe2c0cc
authored
Dec 14, 2022
by
Alan Turner
Browse files
Formatting
parent
4b96da8d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
19 deletions
+17
-19
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+16
-18
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
+1
-1
No files found.
src/targets/gpu/jit/ck_gemm.cpp
View file @
ffe2c0cc
...
@@ -246,21 +246,19 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -246,21 +246,19 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
static
std
::
size_t
get_batch_count
(
const
shape
&
s
)
static
std
::
size_t
get_batch_count
(
const
shape
&
s
)
{
{
return
std
::
accumulate
(
s
.
lens
().
rbegin
()
+
2
,
return
std
::
accumulate
(
s
.
lens
().
rend
(),
s
.
lens
().
rbegin
()
+
2
,
s
.
lens
().
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
}
static
void
fold_batch_dims
(
shape
&
s
)
static
void
fold_batch_dims
(
shape
&
s
)
{
{
auto
lens
=
s
.
lens
();
auto
lens
=
s
.
lens
();
if
(
lens
.
size
()
<=
2
)
if
(
lens
.
size
()
<=
2
)
return
;
return
;
auto
batch_count
=
get_batch_count
(
s
);
auto
batch_count
=
get_batch_count
(
s
);
auto
m1
=
lens
.
at
(
lens
.
size
()
-
2
);
auto
m1
=
lens
.
at
(
lens
.
size
()
-
2
);
auto
m2
=
lens
.
at
(
lens
.
size
()
-
1
);
auto
m2
=
lens
.
at
(
lens
.
size
()
-
1
);
if
(
transposed_matrix
(
s
))
if
(
transposed_matrix
(
s
))
s
=
shape
{
s
.
type
(),
{
m1
,
m2
*
batch_count
}};
s
=
shape
{
s
.
type
(),
{
m1
,
m2
*
batch_count
}};
else
else
s
=
shape
{
s
.
type
(),
{
m1
*
batch_count
,
m2
}};
s
=
shape
{
s
.
type
(),
{
m1
*
batch_count
,
m2
}};
...
@@ -269,11 +267,11 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -269,11 +267,11 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
static
void
remove_batch_dims
(
shape
&
s
)
static
void
remove_batch_dims
(
shape
&
s
)
{
{
auto
lens
=
s
.
lens
();
auto
lens
=
s
.
lens
();
if
(
lens
.
size
()
<=
2
)
if
(
lens
.
size
()
<=
2
)
return
;
return
;
auto
m1
=
lens
.
at
(
lens
.
size
()
-
2
);
auto
m1
=
lens
.
at
(
lens
.
size
()
-
2
);
auto
m2
=
lens
.
at
(
lens
.
size
()
-
1
);
auto
m2
=
lens
.
at
(
lens
.
size
()
-
1
);
s
=
shape
{
s
.
type
(),
{
m1
,
m2
}};
s
=
shape
{
s
.
type
(),
{
m1
,
m2
}};
}
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm"
,
"gpu::ck_gemm"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm"
,
"gpu::ck_gemm"
};
}
...
@@ -284,15 +282,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -284,15 +282,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto
b_shape
=
inputs
[
1
];
auto
b_shape
=
inputs
[
1
];
auto
c_shape
=
inputs
.
back
();
auto
c_shape
=
inputs
.
back
();
auto
rank
=
a_shape
.
lens
().
size
();
auto
rank
=
a_shape
.
lens
().
size
();
auto
b_strides
=
b_shape
.
strides
();
auto
b_strides
=
b_shape
.
strides
();
bool
can_fold_batch
=
rank
>=
3
and
b_strides
[
rank
-
3
]
==
0
;
bool
can_fold_batch
=
rank
>=
3
and
b_strides
[
rank
-
3
]
==
0
;
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
m
=
c_shape
.
lens
()[
rank
-
2
];
auto
m
=
c_shape
.
lens
()[
rank
-
2
];
m
=
can_fold_batch
?
m
*
batch_count
:
m
;
m
=
can_fold_batch
?
m
*
batch_count
:
m
;
auto
n
=
c_shape
.
lens
().
back
();
auto
n
=
c_shape
.
lens
().
back
();
auto
k
=
a_shape
.
lens
().
back
();
auto
k
=
a_shape
.
lens
().
back
();
std
::
array
<
char
,
3
>
keys
{
'M'
,
'N'
,
'K'
};
std
::
array
<
char
,
3
>
keys
{
'M'
,
'N'
,
'K'
};
std
::
array
<
std
::
size_t
,
3
>
config
{
m
,
n
,
k
};
std
::
array
<
std
::
size_t
,
3
>
config
{
m
,
n
,
k
};
auto
tuning_val
=
v
.
get
(
"tuning_val"
,
get_tuning_for
({
a_shape
,
b_shape
,
c_shape
}));
auto
tuning_val
=
v
.
get
(
"tuning_val"
,
get_tuning_for
({
a_shape
,
b_shape
,
c_shape
}));
...
@@ -332,7 +330,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -332,7 +330,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options
.
output
=
c_shape
;
options
.
output
=
c_shape
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"ck_gemm_kernel"
);
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"ck_gemm_kernel"
);
options
.
virtual_inputs
=
inputs
;
options
.
virtual_inputs
=
inputs
;
if
(
can_fold_batch
)
if
(
can_fold_batch
)
{
{
auto
vinputs
=
inputs
;
auto
vinputs
=
inputs
;
fold_batch_dims
(
vinputs
[
0
]);
fold_batch_dims
(
vinputs
[
0
]);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
View file @
ffe2c0cc
...
@@ -53,7 +53,7 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
...
@@ -53,7 +53,7 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
constexpr
const
auto
a_grid_desc_m_k
=
gemm
.
matrix_padder
.
PadADescriptor_M_K
(
to_ck_tensor
<
A
>
());
constexpr
const
auto
a_grid_desc_m_k
=
gemm
.
matrix_padder
.
PadADescriptor_M_K
(
to_ck_tensor
<
A
>
());
constexpr
const
auto
b_grid_desc_n_k
=
constexpr
const
auto
b_grid_desc_n_k
=
gemm
.
matrix_padder
.
PadBDescriptor_N_K
(
to_ck_tensor
<
ck_transposeb
<
B
>>
());
gemm
.
matrix_padder
.
PadBDescriptor_N_K
(
to_ck_tensor
<
ck_transposeb
<
B
>>
());
constexpr
const
auto
e_grid_desc_m_n
=
gemm
.
matrix_padder
.
PadCDescriptor_M_N
(
to_ck_tensor
<
E
>
());
constexpr
const
auto
e_grid_desc_m_n
=
gemm
.
matrix_padder
.
PadCDescriptor_M_N
(
to_ck_tensor
<
E
>
());
constexpr
const
auto
ds_grid_desc_m_n
=
constexpr
const
auto
ds_grid_desc_m_n
=
ck
::
make_tuple
(
gemm
.
matrix_padder
.
PadCDescriptor_M_N
(
to_ck_tensor
<
Ds
>
())...);
ck
::
make_tuple
(
gemm
.
matrix_padder
.
PadCDescriptor_M_N
(
to_ck_tensor
<
Ds
>
())...);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment