Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0a463c1e
Commit
0a463c1e
authored
Aug 29, 2023
by
Alan Turner
Browse files
Formatting
parent
8ab0b22e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
41 deletions
+41
-41
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+3
-7
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
+38
-34
No files found.
src/targets/gpu/fuse_ck.cpp
View file @
0a463c1e
...
@@ -73,7 +73,6 @@ struct ck_gemm
...
@@ -73,7 +73,6 @@ struct ck_gemm
};
};
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
struct
ck_gemm_softmax_gemm
struct
ck_gemm_softmax_gemm
{
{
operation
op
=
make_op
(
"dot"
);
operation
op
=
make_op
(
"dot"
);
...
@@ -107,10 +106,7 @@ struct ck_gemm_softmax_gemm
...
@@ -107,10 +106,7 @@ struct ck_gemm_softmax_gemm
return
op
.
compute_shape
({
op
.
compute_shape
({
a
,
b
}),
b1
});
return
op
.
compute_shape
({
op
.
compute_shape
({
a
,
b
}),
b1
});
}
}
static
bool
is_ck_supported_type
(
shape
::
type_t
t
)
static
bool
is_ck_supported_type
(
shape
::
type_t
t
)
{
return
contains
({
shape
::
half_type
},
t
);
}
{
return
contains
({
shape
::
half_type
},
t
);
}
};
};
MIGRAPHX_REGISTER_OP
(
ck_gemm_softmax_gemm
);
MIGRAPHX_REGISTER_OP
(
ck_gemm_softmax_gemm
);
...
@@ -140,7 +136,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
...
@@ -140,7 +136,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// to avoid poor-performing GEMM kernels from CK
// to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy
// To-do: Investigate a more precise strategy
return
true
;
//k <= 2048;
return
true
;
//
k <= 2048;
}
}
struct
find_ck_gemm_softmax_gemm
struct
find_ck_gemm_softmax_gemm
...
...
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
View file @
0a463c1e
...
@@ -266,7 +266,10 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
...
@@ -266,7 +266,10 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
s
=
shape
{
s
.
type
(),
{
m1
,
m2
}};
s
=
shape
{
s
.
type
(),
{
m1
,
m2
}};
}
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm_softmax_gemm"
,
"gpu::ck_gemm_softmax_gemm"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm_softmax_gemm"
,
"gpu::ck_gemm_softmax_gemm"
};
}
static
bool
standard_batch
(
const
shape
&
s
)
static
bool
standard_batch
(
const
shape
&
s
)
{
{
...
@@ -293,8 +296,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
...
@@ -293,8 +296,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
b_strides
.
begin
(),
b_strides
.
end
()
-
2
,
[](
auto
stride
)
{
return
stride
==
0
;
});
b_strides
.
begin
(),
b_strides
.
end
()
-
2
,
[](
auto
stride
)
{
return
stride
==
0
;
});
}
}
ck
::
host
::
device_batched_gemm_softmax_gemm
::
Problem
create_problem
(
const
std
::
vector
<
shape
>&
inputs
,
ck
::
host
::
device_batched_gemm_softmax_gemm
::
Problem
const
value
&
v
)
const
create_problem
(
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
const
auto
&
a_shape
=
inputs
[
0
];
const
auto
&
a_shape
=
inputs
[
0
];
const
auto
&
b_shape
=
inputs
[
1
];
const
auto
&
b_shape
=
inputs
[
1
];
...
@@ -403,7 +406,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
...
@@ -403,7 +406,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
{
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
*
pm
=
ins
->
module_inputs
().
front
();
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_ck_gemm_softmax_gemm_function"
)
+
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_ck_gemm_softmax_gemm_function"
)
+
"
\n
MIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, post_ck_gemm_softmax_gemm_function);"
;
"
\n
MIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
"post_ck_gemm_softmax_gemm_function);"
;
v
[
"post"
]
=
"ck_function_adaptor<post_ck_gemm_softmax_gemm>"
;
v
[
"post"
]
=
"ck_function_adaptor<post_ck_gemm_softmax_gemm>"
;
v
[
"kernel"
]
=
"ck_gemm_softmax_gemm_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
v
[
"kernel"
]
=
"ck_gemm_softmax_gemm_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
}
...
@@ -423,8 +427,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
...
@@ -423,8 +427,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
{
std
::
vector
<
shape
>
gemm_shapes
{
std
::
vector
<
shape
>
gemm_shapes
{
shapes
[
0
],
shapes
[
1
],
shapes
.
back
().
with_type
(
shapes
[
0
].
type
())};
shapes
[
0
],
shapes
[
1
],
shapes
.
back
().
with_type
(
shapes
[
0
].
type
())};
std
::
cout
<<
"gpu::ck_gemm_softmax_gemm: "
<<
to_json_string
(
to_value
(
gemm_shapes
))
std
::
cout
<<
"gpu::ck_gemm_softmax_gemm: "
<<
std
::
endl
;
<<
to_json_string
(
to_value
(
gemm_shapes
))
<<
std
::
endl
;
}
}
m
.
replace_instruction
(
ins2
,
code_object
,
ins2
->
inputs
());
m
.
replace_instruction
(
ins2
,
code_object
,
ins2
->
inputs
());
}};
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment