Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b621b28f
Commit
b621b28f
authored
May 25, 2023
by
Paul
Browse files
Simplify fuse_ck
parent
fee874e3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
102 deletions
+6
-102
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+5
-101
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+1
-1
No files found.
src/targets/gpu/fuse_ck.cpp
View file @
b621b28f
...
...
@@ -7,8 +7,7 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_CK_GEMM
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_CK_GEMM_FUSION
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_CK_GEMM
);
struct
module
;
...
...
@@ -51,43 +50,6 @@ struct ck_gemm
};
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
struct
ck_gemm_int8
{
operation
op
=
make_op
(
"quant_dot"
);
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
op
,
"op"
));
}
std
::
string
name
()
const
{
return
"gpu::ck_gemm_int8"
;
}
void
check_gemm_shape
(
const
shape
&
s
)
const
{
if
(
not
contains
(
range
(
s
.
strides
().
rbegin
(),
s
.
strides
().
rbegin
()
+
3
),
1
))
MIGRAPHX_THROW
(
"Invalid shape for ck_gemm"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
{
check_shapes
{
inputs
,
*
this
}.
same_ndims
();
// if(mods.size() != 1)
// MIGRAPHX_THROW("should have one submodule.");
if
(
inputs
.
size
()
<
2
)
MIGRAPHX_THROW
(
"should have at least two inputs."
);
auto
a
=
inputs
[
0
];
auto
b
=
inputs
[
1
];
for
(
const
auto
&
input
:
inputs
)
check_gemm_shape
(
input
);
auto
r
=
op
.
compute_shape
({
a
,
b
});
if
(
mods
.
empty
())
return
r
.
with_type
(
migraphx
::
shape
::
int8_type
);
return
r
.
with_type
(
mods
.
front
()
->
get_output_shapes
().
front
().
type
());
}
};
MIGRAPHX_REGISTER_OP
(
ck_gemm_int8
);
namespace
{
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
...
...
@@ -107,7 +69,7 @@ struct find_ck_gemm_pointwise
auto
matcher
()
const
{
auto
gemm
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm"
)));
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"dot"
,
"quant_dot"
)(
is_ck_gemm
().
bind
(
"gemm"
)));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
gemm
.
bind
(
"x"
)));
}
...
...
@@ -123,7 +85,7 @@ struct find_ck_gemm_pointwise
auto
gemm_it
=
std
::
find
(
inputs
.
begin
(),
inputs
.
end
(),
x_ins
);
auto
gemm_idx
=
gemm_it
-
inputs
.
begin
();
assert
(
gemm_it
!=
inputs
.
end
());
if
(
ins
->
get_shape
().
type
()
!=
shape
::
half_
type
)
if
(
not
contains
({
shape
::
half_type
,
shape
::
int8_type
,
shape
::
int32_type
},
ins
->
get_shape
().
type
())
)
return
;
if
(
gemm_idx
!=
0
)
{
...
...
@@ -140,49 +102,7 @@ struct find_ck_gemm_pointwise
inputs
.
erase
(
gemm_it
);
inputs
.
insert
(
inputs
.
begin
(),
gemm_ins
->
inputs
().
begin
(),
gemm_ins
->
inputs
().
end
());
mpm
.
get_module
().
replace_instruction
(
ins
,
ck_gemm
{},
inputs
,
{
pm
});
}
};
struct
find_ck_gemm_pointwise_int8
{
// Find a gemm followed by a pointwise operation.
auto
matcher
()
const
{
auto
gemm
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"quant_dot"
)(
is_ck_gemm
().
bind
(
"gemm"
)));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
gemm
.
bind
(
"x"
)));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
gemm_ins
=
r
.
instructions
[
"gemm"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
// input after contiguous
auto
next_ins
=
std
::
next
(
ins
);
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
names
=
pm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
auto
inputs
=
ins
->
inputs
();
auto
gemm_it
=
std
::
find
(
inputs
.
begin
(),
inputs
.
end
(),
x_ins
);
auto
gemm_idx
=
gemm_it
-
inputs
.
begin
();
assert
(
gemm_it
!=
inputs
.
end
());
if
(
gemm_idx
!=
0
)
{
auto
first_param
=
pm
->
get_parameter
(
names
[
0
]);
auto
gemm_param
=
pm
->
get_parameter
(
names
[
gemm_idx
]);
auto
new_gemm_param
=
pm
->
add_parameter
(
names
[
0
]
+
"_0"
,
gemm_param
->
get_shape
());
auto
new_first_param
=
pm
->
add_parameter
(
names
[
gemm_idx
]
+
"_0"
,
first_param
->
get_shape
());
pm
->
replace_instruction
(
gemm_param
,
new_gemm_param
);
pm
->
replace_instruction
(
first_param
,
new_first_param
);
pm
->
remove_instruction
(
first_param
);
pm
->
remove_instruction
(
gemm_param
);
}
inputs
.
erase
(
gemm_it
);
inputs
.
insert
(
inputs
.
begin
(),
gemm_ins
->
inputs
().
begin
(),
gemm_ins
->
inputs
().
end
());
mpm
.
get_module
().
replace_instruction
(
ins
,
ck_gemm_int8
{},
inputs
,
{
pm
});
mpm
.
get_module
().
replace_instruction
(
ins
,
ck_gemm
{
gemm_ins
->
get_operator
()},
inputs
,
{
pm
});
}
};
...
...
@@ -197,30 +117,14 @@ struct find_ck_gemm
}
};
struct
find_ck_gemm_int8
{
auto
matcher
()
const
{
return
match
::
name
(
"quant_dot"
)(
is_ck_gemm
().
bind
(
"gemm"
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
mpm
.
get_module
().
replace_instruction
(
ins
,
ck_gemm_int8
{
ins
->
get_operator
()},
ins
->
inputs
());
}
};
}
// namespace
void
fuse_ck
::
apply
(
module_pass_manager
&
mpm
)
const
{
if
(
not
enabled
(
MIGRAPHX_
DIS
ABLE_CK_GEMM
_FUSION
{}))
if
(
enabled
(
MIGRAPHX_
EN
ABLE_CK_GEMM
{}))
{
match
::
find_matches
(
mpm
,
find_ck_gemm_pointwise
{});
match
::
find_matches
(
mpm
,
find_ck_gemm_pointwise_int8
{});
}
if
(
not
enabled
(
MIGRAPHX_DISABLE_CK_GEMM
{}))
{
match
::
find_matches
(
mpm
,
find_ck_gemm
{});
match
::
find_matches
(
mpm
,
find_ck_gemm_int8
{});
}
}
...
...
src/targets/gpu/jit/ck_gemm.cpp
View file @
b621b28f
...
...
@@ -222,7 +222,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm"
,
"gpu::ck_gemm"
,
"ck_gemm_int8"
,
"gpu::ck_gemm_int8"
};
return
{
"ck_gemm"
,
"gpu::ck_gemm"
};
}
operation
compile_op
(
context
&
/* ctx */
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment