Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1f106ca7
Commit
1f106ca7
authored
Jul 13, 2023
by
turneram
Browse files
Add envvars for AB testing
parent
f1c8e6c9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
101 additions
and
18 deletions
+101
-18
src/rewrite_quantization.cpp
src/rewrite_quantization.cpp
+28
-8
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+6
-5
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+14
-5
tools/gemm_perf.py
tools/gemm_perf.py
+53
-0
No files found.
src/rewrite_quantization.cpp
View file @
1f106ca7
...
@@ -32,6 +32,8 @@
...
@@ -32,6 +32,8 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_BROADCAST_Q
);
void
apply_quantizelinear
(
module
&
m
,
instruction_ref
ins
)
void
apply_quantizelinear
(
module
&
m
,
instruction_ref
ins
)
{
{
assert
(
ins
->
name
()
==
"quantizelinear"
);
assert
(
ins
->
name
()
==
"quantizelinear"
);
...
@@ -61,15 +63,33 @@ void apply_quantizelinear(module& m, instruction_ref ins)
...
@@ -61,15 +63,33 @@ void apply_quantizelinear(module& m, instruction_ref ins)
max_quant
=
qt
.
max
();
max_quant
=
qt
.
max
();
min_quant
=
qt
.
min
();
min_quant
=
qt
.
min
();
});
});
auto
s
=
add_zero_point
->
get_shape
();
if
(
enabled
(
MIGRAPHX_BROADCAST_Q
{}))
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
{
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
auto
s
=
add_zero_point
->
get_shape
();
auto
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
auto
min_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
min_quant
}});
auto
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
auto
max_arg
=
m
.
add_literal
(
literal
{
shape
{
s
.
type
()},
{
max_quant
}});
auto
min_mbcast
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
min_arg
);
auto
max_mbcast
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
s
.
lens
()}}),
max_arg
);
auto
saturate
=
m
.
insert_instruction
(
ins
,
make_op
(
"clip"
),
add_zero_point
,
min_mbcast
,
max_mbcast
);
m
.
replace_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
ins
->
get_shape
().
type
()}}),
saturate
);
}
else
{
auto
s
=
add_zero_point
->
get_shape
();
std
::
vector
<
int
>
min_data
(
s
.
elements
(),
min_quant
);
std
::
vector
<
int
>
max_data
(
s
.
elements
(),
max_quant
);
auto
min_arg
=
m
.
add_literal
(
literal
(
s
,
min_data
));
auto
max_arg
=
m
.
add_literal
(
literal
(
s
,
max_data
));
auto
saturate
=
m
.
insert_instruction
(
ins
,
make_op
(
"clip"
),
add_zero_point
,
min_arg
,
max_arg
);
auto
saturate
=
m
.
insert_instruction
(
ins
,
make_op
(
"clip"
),
add_zero_point
,
min_arg
,
max_arg
);
m
.
replace_instruction
(
m
.
replace_instruction
(
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
ins
->
get_shape
().
type
()}}),
saturate
);
ins
,
make_op
(
"convert"
,
{{
"target_type"
,
ins
->
get_shape
().
type
()}}),
saturate
);
}
}
}
void
apply_dequantizelinear
(
module
&
m
,
instruction_ref
ins
)
void
apply_dequantizelinear
(
module
&
m
,
instruction_ref
ins
)
...
...
src/simplify_algebra.cpp
View file @
1f106ca7
...
@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
...
@@ -1095,8 +1095,9 @@ MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins)
};
};
};
};
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
dots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"dot"
));
auto
qdots
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"quant_dot"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
auto
convs
=
std
::
count_if
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
pred
(
"convolution"
));
return
(
dots
>=
2
or
convs
>=
2
);
return
(
dots
>=
2
or
convs
>=
2
or
qdots
>=
2
);
}
}
struct
find_conv_dot_horiz_fusion
struct
find_conv_dot_horiz_fusion
...
@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1110,7 +1111,7 @@ struct find_conv_dot_horiz_fusion
auto
pred
=
[](
auto
i
,
auto
j
)
{
auto
pred
=
[](
auto
i
,
auto
j
)
{
if
(
i
->
get_operator
()
!=
j
->
get_operator
())
if
(
i
->
get_operator
()
!=
j
->
get_operator
())
return
false
;
return
false
;
if
(
not
contains
({
"dot"
,
"convolution"
},
i
->
name
()))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
i
->
name
()))
return
true
;
return
true
;
auto
x
=
i
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
x
=
i
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
y
=
j
->
inputs
()[
1
]
->
get_shape
().
lens
();
auto
y
=
j
->
inputs
()[
1
]
->
get_shape
().
lens
();
...
@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1118,7 +1119,7 @@ struct find_conv_dot_horiz_fusion
return
false
;
return
false
;
// Check that non-axes match
// Check that non-axes match
int
axis
=
1
;
int
axis
=
1
;
if
(
i
->
name
()
==
"dot"
)
if
(
i
->
name
()
==
"dot"
or
i
->
name
()
==
"quant_dot"
)
{
{
axis
=
x
.
size
()
-
1
;
axis
=
x
.
size
()
-
1
;
}
}
...
@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1129,7 +1130,7 @@ struct find_conv_dot_horiz_fusion
if
(
std
::
distance
(
start
,
last
)
<
2
)
if
(
std
::
distance
(
start
,
last
)
<
2
)
return
;
return
;
auto
&&
name
=
(
*
start
)
->
name
();
auto
&&
name
=
(
*
start
)
->
name
();
if
(
not
contains
({
"dot"
,
"convolution"
},
name
))
if
(
not
contains
({
"quant_dot"
,
"dot"
,
"convolution"
},
name
))
return
;
return
;
auto
op
=
(
*
start
)
->
get_operator
();
auto
op
=
(
*
start
)
->
get_operator
();
int
group
=
1
;
int
group
=
1
;
...
@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
...
@@ -1144,7 +1145,7 @@ struct find_conv_dot_horiz_fusion
start
,
last
,
std
::
back_inserter
(
args
),
[
&
](
auto
x
)
{
return
x
->
inputs
().
at
(
1
);
});
start
,
last
,
std
::
back_inserter
(
args
),
[
&
](
auto
x
)
{
return
x
->
inputs
().
at
(
1
);
});
int
axis
=
1
;
int
axis
=
1
;
int
concat_axis
=
0
;
int
concat_axis
=
0
;
if
(
name
==
"dot"
)
if
(
name
==
"dot"
or
name
==
"quant_dot"
)
{
{
axis
=
int
(
args
.
front
()
->
get_shape
().
lens
().
size
()
-
1
);
axis
=
int
(
args
.
front
()
->
get_shape
().
lens
().
size
()
-
1
);
concat_axis
=
axis
;
concat_axis
=
axis
;
...
...
src/targets/gpu/fuse_ck.cpp
View file @
1f106ca7
...
@@ -29,6 +29,10 @@
...
@@ -29,6 +29,10 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_USE_LARGE_K
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_CK_FUSION
);
struct
module
;
struct
module
;
namespace
gpu
{
namespace
gpu
{
...
@@ -72,7 +76,7 @@ namespace {
...
@@ -72,7 +76,7 @@ namespace {
bool
is_ck_supported_type
(
shape
::
type_t
t
)
bool
is_ck_supported_type
(
shape
::
type_t
t
)
{
{
return
contains
({
shape
::
half_type
,
shape
::
int8_type
},
t
);
return
contains
({
shape
::
half_type
,
shape
::
int8_type
,
shape
::
int32_type
},
t
);
}
}
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
...
@@ -89,7 +93,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
...
@@ -89,7 +93,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
// Integer gemms must be divisible by 4 in ck
// Integer gemms must be divisible by 4 in ck
if
(
contains
({
shape
::
int8_type
,
shape
::
int32_type
},
ins
->
get_shape
().
type
()))
if
(
contains
({
shape
::
int8_type
,
shape
::
int32_type
},
ins
->
get_shape
().
type
()))
{
{
if
(
m
%
4
!=
0
)
if
(
m
!=
1
and
m
%
4
!=
0
)
return
false
;
return
false
;
if
(
n
%
4
!=
0
)
if
(
n
%
4
!=
0
)
return
false
;
return
false
;
...
@@ -99,7 +103,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
...
@@ -99,7 +103,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
// Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
// Skipping GEMMs with a K dimension greater than 2048 is a coarse-grained strategy
// to avoid poor-performing GEMM kernels from CK
// to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy
// To-do: Investigate a more precise strategy
return
k
<=
2048
;
return
k
<=
2048
or
enabled
(
MIGRAPHX_USE_LARGE_K
{})
;
}
}
struct
find_ck_gemm_pointwise
struct
find_ck_gemm_pointwise
...
@@ -130,6 +134,10 @@ struct find_ck_gemm_pointwise
...
@@ -130,6 +134,10 @@ struct find_ck_gemm_pointwise
return
not
is_ck_supported_type
(
input
->
get_shape
().
type
());
return
not
is_ck_supported_type
(
input
->
get_shape
().
type
());
}))
}))
return
;
return
;
if
(
std
::
any_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[](
auto
input
)
{
return
not
input
->
inputs
().
empty
()
and
input
->
inputs
().
front
()
->
name
()
==
"capture"
;
}))
return
;
assert
(
gemm_it
!=
inputs
.
end
());
assert
(
gemm_it
!=
inputs
.
end
());
if
(
gemm_idx
!=
0
)
if
(
gemm_idx
!=
0
)
{
{
...
@@ -152,7 +160,7 @@ struct find_ck_gemm_pointwise
...
@@ -152,7 +160,7 @@ struct find_ck_gemm_pointwise
struct
find_ck_gemm
struct
find_ck_gemm
{
{
auto
matcher
()
const
{
return
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm"
));
}
auto
matcher
()
const
{
return
match
::
name
(
"dot"
,
"quant_dot"
)(
is_ck_gemm
().
bind
(
"gemm"
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
{
...
@@ -165,7 +173,8 @@ struct find_ck_gemm
...
@@ -165,7 +173,8 @@ struct find_ck_gemm
void
fuse_ck
::
apply
(
module_pass_manager
&
mpm
)
const
void
fuse_ck
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
match
::
find_matches
(
mpm
,
find_ck_gemm_pointwise
{});
if
(
not
enabled
(
MIGRAPHX_DISABLE_CK_FUSION
{}))
match
::
find_matches
(
mpm
,
find_ck_gemm_pointwise
{});
match
::
find_matches
(
mpm
,
find_ck_gemm
{});
match
::
find_matches
(
mpm
,
find_ck_gemm
{});
}
}
...
...
tools/gemm_perf.py
0 → 100644
View file @
1f106ca7
import
subprocess
,
csv
,
re
def get_device_name():
    """Return the first ``gfx<digits><letters>`` token printed by ``rocminfo``.

    Returns:
        str: the first match of ``gfx\\d*[a-z]*`` in the command's stdout.

    Raises:
        subprocess.CalledProcessError: if ``rocminfo`` exits with a non-zero
            status (``check=True``).
        IndexError: if the output contains no ``gfx...`` token.
    """
    out = subprocess.run("rocminfo",
                         capture_output=True,
                         check=True,
                         shell=True)
    # Match against the decoded text rather than the repr() of the bytes
    # object, so escape sequences in the repr can never leak into the result.
    text = out.stdout.decode(errors="replace")
    # Raw string: "\d" inside a plain literal is an invalid escape sequence.
    matches = re.findall(r"gfx\d*[a-z]*", text)
    if not matches:
        raise IndexError("no gfx device name found in rocminfo output")
    return matches[0]
def run_perf(model,
             batch_size,
             int8=False,
             use_ck=False,
             use_large_k=False,
             disable_fusion=False):
    """Run ``../build/bin/driver perf`` on ``model`` and record its summary.

    The command is executed through the shell with the selected environment
    variables prefixed. Everything from the first ``Summary`` in the command's
    stdout to the end of the output is printed, followed by the number taken
    from its ``Total time: `` line, and the command plus summary are written
    to ``summaries.txt`` in the current working directory.

    Args:
        model: path of the model file passed to the driver.
        batch_size: value used for both ``--input-dim @input_ids`` and
            ``--batch``.
        int8: when true, pass ``--int8`` to the driver.
        use_ck: when true, set ``MIGRAPHX_ENABLE_CK=1``.
        use_large_k: when true, set ``MIGRAPHX_USE_LARGE_K=1``.
        disable_fusion: when true, set ``MIGRAPHX_DISABLE_CK_FUSION=1``.

    Raises:
        subprocess.CalledProcessError: if the driver exits with a non-zero
            status (``check=True``).
        IndexError: if the output has no ``Summary`` section or the summary
            has no ``Total time: <number>`` entry.
    """
    env_vars = ""
    if use_ck:
        env_vars += "MIGRAPHX_ENABLE_CK=1 "
    if use_large_k:
        env_vars += "MIGRAPHX_USE_LARGE_K=1 "
    if disable_fusion:
        env_vars += "MIGRAPHX_DISABLE_CK_FUSION=1 "
    int8_str = "--int8" if int8 else ""
    cmd = ("{env_vars} ../build/bin/driver perf {model} --fill1 input_ids "
           "--input-dim @input_ids {batch_size} 384 --batch {batch_size} "
           "--fp16 {int8} --exhaustive-tune").format(
               env_vars=env_vars,
               model=model,
               batch_size=str(batch_size),
               int8=int8_str)
    out = subprocess.run(cmd, capture_output=True, check=True, shell=True)

    # Decode the captured bytes instead of parsing their repr(): the repr
    # leaves a trailing quote and escaped control characters in the summary.
    text = out.stdout.decode(errors="replace")
    # DOTALL keeps the span "from 'Summary' to the end of the output" now that
    # the text contains real newlines.
    summaries = re.findall(r"Summary.*", text, re.DOTALL)
    if not summaries:
        raise IndexError("no 'Summary' section found in driver output")
    summary = summaries[0]

    times = re.findall(r"Total time: \d+\.\d*", summary)
    if not times:
        raise IndexError("no 'Total time' entry found in driver summary")
    total_time = times[0].replace("Total time: ", "")

    print(summary)
    print(total_time)
    # NOTE: "w+" truncates the file, so only the latest run is kept.
    with open("summaries.txt", "w+") as f:
        f.write(cmd + "\n")
        f.write(summary + "\n\n")
# run model with:
# RocBlas
# Get gemm info
# CK
# With fusions
# Without fusions
# Script entry point: query the device name, then benchmark one hard-coded
# model with every run_perf option switched on.
if __name__ == "__main__":
    device_id = get_device_name()  # not used further in this block
    model = "/code/bert_base_cased_1_fp16_gpu.onnx"
    run_perf(model,
             batch_size=1,
             int8=True,
             use_ck=True,
             use_large_k=True,
             disable_fusion=True)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment