Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7c636a51
Commit
7c636a51
authored
Mar 28, 2023
by
Alan Turner
Browse files
Formatting
parent
b9f1b198
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
20 deletions
+19
-20
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+7
-6
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
+11
-8
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+1
-6
No files found.
src/targets/gpu/jit/ck_gemm.cpp
View file @
7c636a51
...
@@ -349,9 +349,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -349,9 +349,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
std
::
array
<
std
::
size_t
,
3
>
config
{
m
,
n
,
k
};
std
::
array
<
std
::
size_t
,
3
>
config
{
m
,
n
,
k
};
auto
tuning_val
=
v
.
get
(
"tuning_val"
,
get_tuning_for
({
a_shape
,
b_shape
,
c_shape
}));
auto
tuning_val
=
v
.
get
(
"tuning_val"
,
get_tuning_for
({
a_shape
,
b_shape
,
c_shape
}));
auto
ip
=
instance
{
get_instance
(
tuning_val
,
[
&
](
const
auto
&
x
)
->
bool
{
auto
ip
=
instance
{
get_instance
(
tuning_val
,
[
&
](
const
auto
&
x
)
->
bool
{
return
true
;
/* get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
return
true
;
/* get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5] and get_type(c_shape) == x[9]; */
get_type(b_shape) == x[5] and get_type(c_shape) == x[9]; */
})};
})};
assert
(
inputs
.
size
()
<
4
or
v
.
contains
(
"post"
));
assert
(
inputs
.
size
()
<
4
or
v
.
contains
(
"post"
));
if
(
v
.
contains
(
"post"
))
if
(
v
.
contains
(
"post"
))
...
@@ -374,13 +374,14 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -374,13 +374,14 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
gemm_type
+=
"Padding"
;
gemm_type
+=
"Padding"
;
ip
.
set_gemm
(
"ck::tensor_operation::device::GemmSpecialization::"
+
gemm_type
);
ip
.
set_gemm
(
"ck::tensor_operation::device::GemmSpecialization::"
+
gemm_type
);
auto
blocks_per_batch
=
int_div_ceil
(
m
,
128
)
*
int_div_ceil
(
n
,
128
);;
//ip.get_grid_size(config);
auto
blocks_per_batch
=
int_div_ceil
(
m
,
128
)
*
int_div_ceil
(
n
,
128
);
;
// ip.get_grid_size(config);
hip_compile_options
options
;
hip_compile_options
options
;
auto
block_size
=
256
;
//ip.get_block_size();
auto
block_size
=
256
;
//
ip.get_block_size();
auto
grid_size
=
can_fold_batch
?
blocks_per_batch
:
batch_count
*
blocks_per_batch
;
auto
grid_size
=
can_fold_batch
?
blocks_per_batch
:
batch_count
*
blocks_per_batch
;
options
.
set_launch_params
(
v
,
grid_size
*
block_size
,
block_size
);
options
.
set_launch_params
(
v
,
grid_size
*
block_size
,
block_size
);
//auto new_inputs = inputs;
//
auto new_inputs = inputs;
auto
new_inputs
=
inputs
;
auto
new_inputs
=
inputs
;
// auto out_s = inputs.back();
// auto out_s = inputs.back();
// new_inputs.back() = shape{shape::int8_type, out_s.lens(), out_s.strides()};
// new_inputs.back() = shape{shape::int8_type, out_s.lens(), out_s.strides()};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
View file @
7c636a51
...
@@ -46,7 +46,9 @@ using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>
...
@@ -46,7 +46,9 @@ using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>
ck_transposeb_dims
(
get_shape_c
<
Tensor
>
{}.
strides
)));
ck_transposeb_dims
(
get_shape_c
<
Tensor
>
{}.
strides
)));
template
<
class
...
Xs
>
template
<
class
...
Xs
>
constexpr
void
noop
(
Xs
...)
{}
constexpr
void
noop
(
Xs
...)
{
}
template
<
class
G
,
class
E
,
class
A
,
class
B
,
class
...
Ds
>
template
<
class
G
,
class
E
,
class
A
,
class
B
,
class
...
Ds
>
__device__
void
ck_gemm_matrix
(
E
e
,
A
a
,
B
b
,
Ds
...
ds
)
__device__
void
ck_gemm_matrix
(
E
e
,
A
a
,
B
b
,
Ds
...
ds
)
...
@@ -78,17 +80,17 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
...
@@ -78,17 +80,17 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
// GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
// GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
// noop(a, b, e, M, N, K, a_grid_desc_k0_m_k1, a_grid_desc_k0_m0_m1_k1, ds...);
// noop(a, b, e, M, N, K, a_grid_desc_k0_m_k1, a_grid_desc_k0_m0_m1_k1, ds...);
//////////////////////////
//////////////////////////
// constexpr const G gemm{};
// constexpr const G gemm{};
// constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>());
// constexpr const auto a_grid_desc_m_k =
// constexpr const auto b_grid_desc_n_k =
// gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>()); constexpr const auto
// b_grid_desc_n_k =
// gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<ck_transposeb<B>>());
// gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<ck_transposeb<B>>());
// constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>());
// constexpr const auto e_grid_desc_m_n =
// constexpr const auto ds_grid_desc_m_n =
// gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>()); constexpr const auto
// ds_grid_desc_m_n =
// ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...);
// ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...);
// constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
// constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
...
@@ -112,7 +114,8 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
...
@@ -112,7 +114,8 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
// __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// constexpr const bool HasMainKBlockLoop =
// constexpr const bool HasMainKBlockLoop =
// GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
// GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{})
// *
// a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
// a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
// GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
// GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
// to_ck_const_pointer(b.data()),
// to_ck_const_pointer(b.data()),
...
...
test/onnx/gen_onnx.py
View file @
7c636a51
...
@@ -4058,12 +4058,7 @@ def int8_gemm_verify():
...
@@ -4058,12 +4058,7 @@ def int8_gemm_verify():
outputs
=
[
'x'
],
outputs
=
[
'x'
],
)
)
convert
=
onnx
.
helper
.
make_node
(
convert
=
onnx
.
helper
.
make_node
(
'Cast'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
],
to
=
6
)
'Cast'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
],
to
=
6
)
return
([
node
,
convert
],
[
m1
,
m2
],
[
y
])
return
([
node
,
convert
],
[
m1
,
m2
],
[
y
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment