Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0a463c1e
Commit
0a463c1e
authored
Aug 29, 2023
by
Alan Turner
Browse files
Formatting
parent
8ab0b22e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
48 additions
and
48 deletions
+48
-48
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+3
-7
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
+38
-34
test/verify/ck_gemm_softmax_gemm.cpp
test/verify/ck_gemm_softmax_gemm.cpp
+7
-7
No files found.
src/targets/gpu/fuse_ck.cpp
View file @
0a463c1e
...
...
@@ -73,7 +73,6 @@ struct ck_gemm
};
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
struct
ck_gemm_softmax_gemm
{
operation
op
=
make_op
(
"dot"
);
...
...
@@ -107,10 +106,7 @@ struct ck_gemm_softmax_gemm
return
op
.
compute_shape
({
op
.
compute_shape
({
a
,
b
}),
b1
});
}
static
bool
is_ck_supported_type
(
shape
::
type_t
t
)
{
return
contains
({
shape
::
half_type
},
t
);
}
static
bool
is_ck_supported_type
(
shape
::
type_t
t
)
{
return
contains
({
shape
::
half_type
},
t
);
}
};
MIGRAPHX_REGISTER_OP
(
ck_gemm_softmax_gemm
);
...
...
@@ -140,7 +136,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
// Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
// to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy
return
true
;
//k <= 2048;
return
true
;
//
k <= 2048;
}
struct
find_ck_gemm_softmax_gemm
...
...
@@ -163,7 +159,7 @@ struct find_ck_gemm_softmax_gemm
// if (not ck_gemm_softmax_gemm::is_ck_supported_type(gemm1_ins->get_shape().type()))
// return;
auto
inputs
=
gemm1_ins
->
inputs
();
// A, B
inputs
.
push_back
(
gemm2_ins
->
inputs
().
back
());
// B1
...
...
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
View file @
0a463c1e
...
...
@@ -266,7 +266,10 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
s
=
shape
{
s
.
type
(),
{
m1
,
m2
}};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm_softmax_gemm"
,
"gpu::ck_gemm_softmax_gemm"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm_softmax_gemm"
,
"gpu::ck_gemm_softmax_gemm"
};
}
static
bool
standard_batch
(
const
shape
&
s
)
{
...
...
@@ -293,13 +296,13 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
b_strides
.
begin
(),
b_strides
.
end
()
-
2
,
[](
auto
stride
)
{
return
stride
==
0
;
});
}
ck
::
host
::
device_batched_gemm_softmax_gemm
::
Problem
create_problem
(
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
ck
::
host
::
device_batched_gemm_softmax_gemm
::
Problem
create_problem
(
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
const
auto
&
a_shape
=
inputs
[
0
];
const
auto
&
b_shape
=
inputs
[
1
];
const
auto
&
a_shape
=
inputs
[
0
];
const
auto
&
b_shape
=
inputs
[
1
];
const
auto
&
b1_shape
=
inputs
[
2
];
const
auto
&
c_shape
=
inputs
.
back
();
const
auto
&
c_shape
=
inputs
.
back
();
// cppcheck-suppress unreadVariable
auto
rank
=
a_shape
.
ndim
();
...
...
@@ -311,37 +314,37 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
auto
k
=
a_shape
.
lens
().
back
();
auto
o
=
c_shape
.
lens
().
back
();
const
bool
trans_a
=
transposed_matrix
(
a_shape
);
const
bool
trans_b
=
transposed_matrix
(
b_shape
);
const
bool
trans_a
=
transposed_matrix
(
a_shape
);
const
bool
trans_b
=
transposed_matrix
(
b_shape
);
const
bool
trans_b1
=
transposed_matrix
(
b1_shape
);
const
bool
trans_c
=
transposed_matrix
(
c_shape
);
const
auto
a_type
=
get_type
(
a_shape
);
const
auto
b_type
=
get_type
(
b_shape
);
const
bool
trans_c
=
transposed_matrix
(
c_shape
);
const
auto
a_type
=
get_type
(
a_shape
);
const
auto
b_type
=
get_type
(
b_shape
);
const
auto
b1_type
=
get_type
(
b1_shape
);
const
auto
c_type
=
get_type
(
c_shape
);
const
auto
scale
=
1.0
f
;
const
auto
c_type
=
get_type
(
c_shape
);
const
auto
scale
=
1.0
f
;
std
::
string
ck_passthrough
=
"ck_passthrough"
;
std
::
string
cde_op
=
ck_passthrough
;
/// update params after adding to jitlib
return
ck
::
host
::
device_batched_gemm_softmax_gemm
::
Problem
{
m
,
n
,
k
,
o
,
trans_a
,
trans_b
,
trans_b1
,
trans_c
,
a_type
,
b_type
,
b1_type
,
c_type
,
ck_passthrough
,
ck_passthrough
,
ck_passthrough
,
ck_passthrough
,
scale
};
n
,
k
,
o
,
trans_a
,
trans_b
,
trans_b1
,
trans_c
,
a_type
,
b_type
,
b1_type
,
c_type
,
ck_passthrough
,
ck_passthrough
,
ck_passthrough
,
ck_passthrough
,
scale
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
...
...
@@ -350,7 +353,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
const
auto
&
b_shape
=
inputs
[
1
];
const
auto
&
c_shape
=
inputs
.
back
();
/// update for 4-arg lookup?
auto
tuning_value
=
v
.
get
(
"tuning_value"
,
4
);
auto
tuning_value
=
v
.
get
(
"tuning_value"
,
4
);
if
(
not
v
.
contains
(
"tuning_value"
))
tuning_value
=
get_tuning_for
({
a_shape
,
b_shape
,
c_shape
});
auto
batch_count
=
get_batch_count
(
c_shape
);
...
...
@@ -403,7 +406,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_ck_gemm_softmax_gemm_function"
)
+
"
\n
MIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, post_ck_gemm_softmax_gemm_function);"
;
"
\n
MIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
"post_ck_gemm_softmax_gemm_function);"
;
v
[
"post"
]
=
"ck_function_adaptor<post_ck_gemm_softmax_gemm>"
;
v
[
"kernel"
]
=
"ck_gemm_softmax_gemm_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
...
...
@@ -423,8 +427,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
std
::
vector
<
shape
>
gemm_shapes
{
shapes
[
0
],
shapes
[
1
],
shapes
.
back
().
with_type
(
shapes
[
0
].
type
())};
std
::
cout
<<
"gpu::ck_gemm_softmax_gemm: "
<<
to_json_string
(
to_value
(
gemm_shapes
))
<<
std
::
endl
;
std
::
cout
<<
"gpu::ck_gemm_softmax_gemm: "
<<
to_json_string
(
to_value
(
gemm_shapes
))
<<
std
::
endl
;
}
m
.
replace_instruction
(
ins2
,
code_object
,
ins2
->
inputs
());
}};
...
...
test/verify/ck_gemm_softmax_gemm.cpp
View file @
0a463c1e
...
...
@@ -36,10 +36,10 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
12
,
256
,
256
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
12
,
256
,
256
}};
auto
m2_elements
=
1
*
12
*
256
*
256
;
auto
a
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
b
=
mm
->
add_parameter
(
"2"
,
m1_shape
);
auto
b1
=
mm
->
add_parameter
(
"3"
,
m1_shape
);
auto
c
=
mm
->
add_parameter
(
"4"
,
m1_shape
);
auto
a
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
b
=
mm
->
add_parameter
(
"2"
,
m1_shape
);
auto
b1
=
mm
->
add_parameter
(
"3"
,
m1_shape
);
auto
c
=
mm
->
add_parameter
(
"4"
,
m1_shape
);
std
::
vector
<
float
>
eights
(
m2_elements
,
0.125
);
auto
eight
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
eights
});
std
::
vector
<
float
>
zeros
(
m2_elements
,
0
);
...
...
@@ -48,9 +48,9 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
auto
one
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
ones
});
b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
b
);
auto
gemm1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
b
);
auto
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
gemm1
,
eight
);
auto
bias
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
scale
,
zero
);
auto
gemm1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
b
);
auto
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
gemm1
,
eight
);
auto
bias
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
scale
,
zero
);
auto
softmax
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"softmax"
,
{{
"axis"
,
-
1
}}),
bias
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
softmax
,
b1
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment