gaoqiong/MIGraphX commit 99626b4c
Authored May 19, 2023 by Alan Turner
Parent: 3d0426e9

Enable int8 gemm-pointwise fusion
Showing 7 changed files with 2996 additions and 46 deletions.
src/include/migraphx/op/quant_dot.hpp      +1    -1
src/quantization.cpp                       +21   -21
src/rewrite_quantization.cpp               +9    -5
src/targets/gpu/fuse_ck.cpp                +144  -6
src/targets/gpu/jit/ck_gemm.cpp            +64   -6
src/targets/gpu/jit/ck_gemm_instances.cpp  +2746 -1
tools/tune_models.py                       +11   -6
src/include/migraphx/op/quant_dot.hpp
@@ -73,7 +73,7 @@ struct quant_dot
         auto out_lens   = a.lens();
         out_lens[dim_1] = b.lens()[dim_1];
-        return {shape::int32_type, out_lens};
+        return {shape::int8_type, out_lens};
     }
 };
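With this change, quant_dot's computed output shape reports int8 rather than int32, which lines up with the default output type that the new ck_gemm_int8 operator in src/targets/gpu/fuse_ck.cpp (below) reports when it carries no fused pointwise module.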
src/quantization.cpp
@@ -112,28 +112,28 @@ void quantize_int8(program& prog,
     max_abs_vals->resize(param_num, 0.0f);

     // use the calibration data to compute the quantization scale
-    auto capture_prog = prog;
-    capture_prog.compile(t);
+    // auto capture_prog = prog;
+    // capture_prog.compile(t);

-    // use all calibration data to run the program to calculate the
-    // quantization scale and shift
-    for(auto&& arg : calibration)
-    {
-        parameter_map m;
-        for(auto&& x : capture_prog.get_parameter_shapes())
-        {
-            if(arg.count(x.first) > 0)
-            {
-                assert(x.second == arg.at(x.first).get_shape());
-                m[x.first] = t.copy_to(arg.at(x.first));
-            }
-            else
-            {
-                m[x.first] = t.allocate(x.second);
-            }
-        }
-        capture_prog.eval(m);
-    }
+    // // use all calibration data to run the program to calculate the
+    // // quantization scale and shift
+    // for(auto&& arg : calibration)
+    // {
+    //     parameter_map m;
+    //     for(auto&& x : capture_prog.get_parameter_shapes())
+    //     {
+    //         if(arg.count(x.first) > 0)
+    //         {
+    //             assert(x.second == arg.at(x.first).get_shape());
+    //             m[x.first] = t.copy_to(arg.at(x.first));
+    //         }
+    //         else
+    //         {
+    //             m[x.first] = t.allocate(x.second);
+    //         }
+    //     }
+    //     capture_prog.eval(m);
+    // }

     // print the quantization parameters in only the main module
     if(enabled(MIGRAPHX_INT8_QUANTIZATION_PARAMS{}))
src/rewrite_quantization.cpp
@@ -37,10 +37,11 @@ void apply_quantizelinear(module& m, instruction_ref ins)
     assert(ins->name() == "quantizelinear");
     auto x       = ins->inputs()[0];
     auto y_scale = ins->inputs()[1];
+    auto scale_type = y_scale->get_shape().type();
     if(x->get_shape().type() != y_scale->get_shape().type())
     {
-        x = m.insert_instruction(ins, make_op("convert", {{"target_type", shape::float_type}}), x);
+        x = m.insert_instruction(ins, make_op("convert", {{"target_type", scale_type}}), x);
     }
     auto div            = m.insert_instruction(ins, make_op("div"), x, y_scale);
     auto add_zero_point = m.insert_instruction(ins, make_op("round"), div);
@@ -48,7 +49,7 @@ void apply_quantizelinear(module& m, instruction_ref ins)
     if(ins->inputs().size() == 3)
     {
         auto zero_point = m.insert_instruction(
-            ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
+            ins, make_op("convert", {{"target_type", scale_type}}), ins->inputs()[2]);
         add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
     }
@@ -72,14 +73,15 @@ void apply_quantizelinear(module& m, instruction_ref ins)
 void apply_dequantizelinear(module& m, instruction_ref ins)
 {
     assert(ins->name() == "dequantizelinear");
-    auto x = m.insert_instruction(
-        ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[0]);
     auto x_scale = ins->inputs()[1];
+    auto scale_type = x_scale->get_shape().type();
+    auto x = m.insert_instruction(
+        ins, make_op("convert", {{"target_type", scale_type}}), ins->inputs()[0]);
     if(ins->inputs().size() == 3)
     {
         auto x_zero_point = m.insert_instruction(
-            ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
+            ins, make_op("convert", {{"target_type", scale_type}}), ins->inputs()[2]);
         x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point);
     }
@@ -100,6 +102,8 @@ void rewrite_quantization::apply(module& m) const
             apply_dequantizelinear(m, ins);
         }
     }
+    // std::cout << "after rwq: " << std::endl;
+    // m.debug_print();
 }
 } // namespace MIGRAPHX_INLINE_NS
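The rewrite now takes its working type from the scale input (scale_type) instead of hard-coding shape::float_type. As a rough, standalone scalar sketch of the arithmetic that the inserted convert/div/round/add instructions build on the quantize side, and convert/sub on the dequantize side (the final multiply by the scale sits outside the hunks shown), with a float scale assumed and explicit int8 saturation added purely for illustration:

// Standalone sketch (not MIGraphX code) of the quantize/dequantize arithmetic
// that the rewritten pass assembles out of graph instructions.
// Assumptions: float scale, int8 zero point; clamping to [-128, 127] is added
// here only to keep the example self-contained.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

std::int8_t quantize(float x, float scale, std::int8_t zero_point)
{
    // div -> round -> add zero point, then saturate to int8
    float q = std::nearbyint(x / scale) + static_cast<float>(zero_point);
    return static_cast<std::int8_t>(std::clamp(q, -128.0f, 127.0f));
}

float dequantize(std::int8_t y, float scale, std::int8_t zero_point)
{
    // convert -> sub zero point -> multiply by the scale
    return (static_cast<float>(y) - static_cast<float>(zero_point)) * scale;
}

int main()
{
    float x = 1.7f, scale = 0.05f;
    std::int8_t zp = 0;
    auto q = quantize(x, scale, zp);
    std::cout << static_cast<int>(q) << " -> " << dequantize(q, scale, zp) << "\n";
}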
src/targets/gpu/fuse_ck.cpp
@@ -4,7 +4,6 @@
 #include <migraphx/make_op.hpp>
 #include <migraphx/register_op.hpp>
 #include <migraphx/env.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -52,16 +51,53 @@ struct ck_gemm
 };
 MIGRAPHX_REGISTER_OP(ck_gemm);

+struct ck_gemm_int8
+{
+    operation op = make_op("quant_dot");
+
+    template <class Self, class F>
+    static auto reflect(Self& self, F f)
+    {
+        return pack(f(self.op, "op"));
+    }
+
+    std::string name() const { return "gpu::ck_gemm_int8"; }
+
+    void check_gemm_shape(const shape& s) const
+    {
+        if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
+            MIGRAPHX_THROW("Invalid shape for ck_gemm");
+    }
+
+    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
+    {
+        check_shapes{inputs, *this}.same_ndims();
+        // if(mods.size() != 1)
+        //     MIGRAPHX_THROW("should have one submodule.");
+        if(inputs.size() < 2)
+            MIGRAPHX_THROW("should have at least two inputs.");
+        auto a = inputs[0];
+        auto b = inputs[1];
+        for(const auto& input : inputs)
+            check_gemm_shape(input);
+        auto r = op.compute_shape({a, b});
+        if(mods.empty())
+            return r.with_type(migraphx::shape::int8_type);
+        return r.with_type(mods.front()->get_output_shapes().front().type());
+    }
+};
+MIGRAPHX_REGISTER_OP(ck_gemm_int8);
+
 namespace {

 MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
 {
-    if(ins->name() != "dot")
+    if(ins->name() != "dot" and ins->name() != "quant_dot")
         return false;
     auto a = ins->inputs().front()->get_shape();
     auto b = ins->inputs().back()->get_shape();
-    if(a.lens().back() > 2048)
-        return false;
+    // if(a.lens().back() > 2048)
+    //     return false;
     return true;
 }
@@ -93,9 +129,9 @@ struct find_ck_gemm_pointwise
         {
             auto first_param    = pm->get_parameter(names[0]);
             auto gemm_param     = pm->get_parameter(names[gemm_idx]);
-            auto new_gemm_param = pm->add_parameter(names[0] + ".0", gemm_param->get_shape());
+            auto new_gemm_param = pm->add_parameter(names[0] + "_0", gemm_param->get_shape());
             auto new_first_param =
-                pm->add_parameter(names[gemm_idx] + ".0", first_param->get_shape());
+                pm->add_parameter(names[gemm_idx] + "_0", first_param->get_shape());
             pm->replace_instruction(gemm_param, new_gemm_param);
             pm->replace_instruction(first_param, new_first_param);
             pm->remove_instruction(first_param);
@@ -108,6 +144,91 @@ struct find_ck_gemm_pointwise
     }
 };

+struct find_ck_gemm_pointwise_int8
+{
+    // Find a gemm followed by a pointwise operation.
+    auto matcher() const
+    {
+        auto gemm = match::skip(match::name("contiguous"))(
+            match::name("quant_dot")(is_ck_gemm().bind("gemm")));
+        return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
+    }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        auto ins      = r.result;
+        auto gemm_ins = r.instructions["gemm"];
+        auto x_ins    = r.instructions["x"]; // input after contiguous
+        auto next_ins = std::next(ins);
+        // if (next_ins->name() == "quant_dot")
+        // {
+        //     std::cout << "\nins: ";
+        //     ins->debug_print();
+        //     std::cout << "\ngemm_ins: ";
+        //     gemm_ins->debug_print();
+        //     std::cout << "\nx_ins: ";
+        //     x_ins->debug_print();
+        //     std::cout << "\nnext: ";
+        //     next_ins->debug_print();
+        //     mpm.get_module().debug_print();
+        // }
+        auto* pm   = ins->module_inputs().front();
+        auto names = pm->get_parameter_names();
+        std::sort(names.begin(), names.end());
+        auto inputs   = ins->inputs();
+        auto gemm_it  = std::find(inputs.begin(), inputs.end(), x_ins);
+        auto gemm_idx = gemm_it - inputs.begin();
+        assert(gemm_it != inputs.end());
+        // if(ins->get_shape().type() != shape::half_type)
+        //     return;
+        // if (next_ins->name() == "reshape")
+        // {
+        //     std::cout << "PM before: " << std::endl;
+        //     pm->debug_print();
+        // }
+        if(gemm_idx != 0)
+        {
+            auto first_param    = pm->get_parameter(names[0]);
+            auto gemm_param     = pm->get_parameter(names[gemm_idx]);
+            auto new_gemm_param = pm->add_parameter(names[0] + "_0", gemm_param->get_shape());
+            auto new_first_param =
+                pm->add_parameter(names[gemm_idx] + "_0", first_param->get_shape());
+            pm->replace_instruction(gemm_param, new_gemm_param);
+            pm->replace_instruction(first_param, new_first_param);
+            pm->remove_instruction(first_param);
+            pm->remove_instruction(gemm_param);
+        }
+        // if (next_ins->name() == "reshape")
+        // {
+        //     std::cout << "PM after: " << std::endl;
+        //     pm->debug_print();
+        // }
+        inputs.erase(gemm_it);
+        inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());
+        // std::cout << "Next_ins inputs: " << std::endl;
+        // for (auto& in : next_ins->inputs())
+        // {
+        //     in->debug_print();
+        // }
+        // auto out_shape = compute_shape(ck_gemm_int8{}, inputs, {pm});
+        // instruction::replace(ins, ck_gemm_int8{}, out_shape.with_type(migraphx::shape::half_type), inputs, {pm});
+        mpm.get_module().replace_instruction(ins, ck_gemm_int8{}, inputs, {pm});
+        // std::cout << "Next_ins inputs (post replace): " << std::endl;
+        // for (auto& in : std::next(ins)->inputs())
+        // {
+        //     in->debug_print();
+        // }
+        // if (next_ins->name() == "softmax" or next_ins->name() == "reshape")
+        // {
+        //     std::cout << "After replace: " << std::endl;
+        //     mpm.get_module().debug_print();
+        // }
+    }
+};
+
 struct find_ck_gemm
 {
     auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
@@ -119,14 +240,31 @@ struct find_ck_gemm
     }
 };

+struct find_ck_gemm_int8
+{
+    auto matcher() const { return match::name("quant_dot")(is_ck_gemm().bind("gemm")); }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        auto ins = r.result;
+        mpm.get_module().replace_instruction(
+            ins, ck_gemm_int8{ins->get_operator()}, ins->inputs());
+    }
+};
+
 } // namespace

 void fuse_ck::apply(module_pass_manager& mpm) const
 {
     if(not enabled(MIGRAPHX_DISABLE_CK_GEMM_FUSION{}))
+    {
         match::find_matches(mpm, find_ck_gemm_pointwise{});
+        match::find_matches(mpm, find_ck_gemm_pointwise_int8{});
+    }
     if(not enabled(MIGRAPHX_DISABLE_CK_GEMM{}))
+    {
         match::find_matches(mpm, find_ck_gemm{});
+        match::find_matches(mpm, find_ck_gemm_int8{});
+    }
 }
 } // namespace gpu
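For orientation, here is a toy, self-contained sketch (plain C++, not the MIGraphX matcher or module API) of the rewrite performed by find_ck_gemm_pointwise_int8: find a pointwise instruction fed by a quant_dot and collapse the pair into one fused node whose input list starts with the gemm's inputs, followed by the remaining pointwise inputs.

// Toy illustration of the gemm + pointwise fusion idea. Names such as `node`
// and `fuse_quant_dot_pointwise` are invented for this example.
#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct node
{
    std::string name;
    std::vector<node*> inputs;
};

// Rewrite every "pointwise" node that consumes a "quant_dot" result.
void fuse_quant_dot_pointwise(std::vector<std::unique_ptr<node>>& graph)
{
    for(auto& n : graph)
    {
        if(n->name != "pointwise")
            continue;
        auto it = std::find_if(n->inputs.begin(), n->inputs.end(), [](node* in) {
            return in->name == "quant_dot";
        });
        if(it == n->inputs.end())
            continue;
        node* gemm = *it;
        // Gemm inputs first, then the other pointwise inputs
        // (mirrors the erase/insert done on `inputs` in the real pass).
        std::vector<node*> new_inputs = gemm->inputs;
        for(node* in : n->inputs)
            if(in != gemm)
                new_inputs.push_back(in);
        n->name   = "ck_gemm_int8";
        n->inputs = std::move(new_inputs);
    }
}

int main()
{
    std::vector<std::unique_ptr<node>> g;
    g.push_back(std::make_unique<node>(node{"a", {}}));
    g.push_back(std::make_unique<node>(node{"b", {}}));
    g.push_back(std::make_unique<node>(node{"quant_dot", {g[0].get(), g[1].get()}}));
    g.push_back(std::make_unique<node>(node{"bias", {}}));
    g.push_back(std::make_unique<node>(node{"pointwise", {g[2].get(), g[3].get()}}));
    fuse_quant_dot_pointwise(g);
    std::cout << g.back()->name << " now has " << g.back()->inputs.size() << " inputs\n";
}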
src/targets/gpu/jit/ck_gemm.cpp
@@ -122,6 +122,12 @@ struct instance
         params[8] = s;
     }

+    void set_e_type(const std::string& s)
+    {
+        // assert(params[9] == "ck::Tuple<>");
+        params[9] = s;
+    }
+
     void set_ds_op(const std::string& s)
     {
         assert(params[12] == "ck_passthrough");
@@ -134,6 +140,23 @@ struct instance
         params[13] = s;
     }

+    void set_a_scalar_per_vec(const std::string& s)
+    {
+        params[block_size_index + 14] = s;
+        params[block_size_index + 15] = s;
+    }
+
+    void set_b_scalar_per_vec(const std::string& s)
+    {
+        params[block_size_index + 20] = s;
+        params[block_size_index + 21] = s;
+    }
+
+    void set_c_scalar_per_vec(const std::string& s)
+    {
+        params[params.size() - 3] = s;
+    }
+
     std::string str() const { return join_strings(params, ","); }
 };
@@ -175,12 +198,20 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
 {
     static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
     if(tuning.empty())
-        std::cout << "*********** Warning: No CK tuning!" << std::endl;
+    {
+        std::cout << "*********** Warning: No CK tuning! for config:" << std::endl;
+        std::cout << "  " << inputs[0] << std::endl;
+        std::cout << "  " << inputs[1] << std::endl;
+        std::cout << "  " << inputs[2] << std::endl;
+    }
     auto it = std::find_if(
         tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
     if(it == tuning.end())
     {
         std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
+        std::cout << "  " << inputs[0] << std::endl;
+        std::cout << "  " << inputs[1] << std::endl;
+        std::cout << "  " << inputs[2] << std::endl;
         std::vector<std::pair<float, std::size_t>> w;
         std::transform(tuning.begin(), tuning.end(), std::back_inserter(w), [&](const auto& p) {
             if(inputs.size() < 3 or p.first.size() < 3)
@@ -274,7 +305,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         s = shape{s.type(), {m1, m2}};
     }

-    std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }
+    std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm", "ck_gemm_int8", "gpu::ck_gemm_int8"}; }

     operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
     {
@@ -293,11 +324,27 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         auto k = a_shape.lens().back();

         std::array<char, 3> keys{'M', 'N', 'K'};
         std::array<std::size_t, 3> config{m, n, k};
-        auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
+        auto tuning_val =
+            v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape.with_type(a_shape.type())}));
         auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
+            // if (not (get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
+            //          get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
+            //          get_type(b_shape) == x[5] and get_type(c_shape) == x[9]))
+            // {
+            //     std::cout << get_layout(a_shape) << " - " << x[0] << std::endl;
+            //     std::cout << get_layout(b_shape) << " - " << x[1] << std::endl;
+            //     std::cout << get_layout(c_shape) << " - " << x[3] << std::endl;
+            //     std::cout << get_type(a_shape) << " - " << x[4] << std::endl;
+            //     std::cout << get_type(b_shape) << " - " << x[5] << std::endl;
+            //     std::cout << get_type(c_shape) << " - " << x[9] << std::endl;
+            // }
+            /* return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
+                      get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
+                      get_type(b_shape) == x[5] and get_type(c_shape) == x[9]; */
             return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
                    get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
-                   get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
+                   get_type(b_shape) == x[5];
         })};
         assert(inputs.size() < 4 or v.contains("post"));
         if(v.contains("post"))
@@ -305,7 +352,18 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
             ip.set_ds_layout(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout));
             ip.set_ds_type(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type));
             ip.set_ds_op(v.at("post").to<std::string>());
         }
+        ip.set_e_type(get_type(c_shape));
+        if(std::any_of(inputs.begin(), inputs.end(), [](auto s) { return get_type(s) == "ck::half_t"; }))
+        {
+            ip.set_c_scalar_per_vec("8");
+        }
+        if(std::any_of(inputs.begin(), inputs.end(), [](auto s) { return get_type(s) == "float"; }))
+        {
+            ip.set_c_scalar_per_vec("4");
+        }

         auto padding = ip.get_pad(config);
         std::string gemm_type;
@@ -349,7 +407,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
             {"blocks_per_batch", to_string(blocks_per_batch)},
             {"preamble", v.get("preamble", std::string{})},
             {"kernel", options.kernel_name}});

+        // std::cout << options.kernel_name << ": " << std::endl;
         return compile_hip_code_object(src, options);
     }
@@ -370,7 +428,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         return action_decorate(replace(compile_op(ctx, shapes, v)), [=] {
             if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
             {
-                std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
+                std::vector<shape> gemm_shapes{
+                    shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())};
                 std::cout << "ck_gemm: " << to_json_string(to_value(gemm_shapes)) << std::endl;
             }
         });
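The instance struct patched above treats a composable-kernel instance as a flat, comma-joined list of template-parameter strings; the new setters (set_e_type, set_a_scalar_per_vec, set_b_scalar_per_vec, set_c_scalar_per_vec) simply overwrite fixed positions in that list before the kernel source is generated. A standalone sketch of that pattern, with invented indices and values used purely for illustration:

// Sketch of the "patch a comma-joined parameter list" pattern used by the
// instance struct. The indices and parameter strings below are made up; the
// real positions are the ones hard-coded in ck_gemm.cpp.
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

struct instance_sketch
{
    std::vector<std::string> params;

    // Hypothetical positions for the output (E) type and the C scalar-per-vector width.
    void set_e_type(const std::string& s) { params[3] = s; }
    void set_c_scalar_per_vec(const std::string& s) { params.back() = s; }

    std::string str() const
    {
        // Equivalent of join_strings(params, ",")
        return std::accumulate(
            std::next(params.begin()), params.end(), params.front(),
            [](std::string a, const std::string& b) { return a + "," + b; });
    }
};

int main()
{
    instance_sketch ip{{"Row", "Row", "Row", "ck::half_t", "8"}};
    ip.set_e_type("int8_t");
    ip.set_c_scalar_per_vec("4");
    std::cout << ip.str() << "\n"; // prints: Row,Row,Row,int8_t,4
}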
src/targets/gpu/jit/ck_gemm_instances.cpp
This diff is collapsed (+2746, -1).
tools/tune_models.py
@@ -30,21 +30,26 @@ def parse_args():
         type=str,
         help='Existing tuning JSON. Configs already present will not be re-tuned.')
+    parser.add_argument("-q", "--quantize_int8", action="store_true")
     args = parser.parse_args()
     return args


-def tune_models(models, batch_sizes, seq_len, n, existing):
+def tune_models(models, batch_sizes, seq_len, n, existing, q_int8):
     time_stamp = time.strftime("%Y_%m_%d_%H_%M")
     log_file = "ck_tuning_{}.log".format(time_stamp)
     json_file = "ck_tuning_{}.json".format(time_stamp)
+    prec_str = "--int8" if q_int8 else ""
     for model in models:
         for batch in batch_sizes:
-            params = "--input-dim @sample {} 4 64 64 @timestep 1 @encoder_hidden_states {} 64 1024 --fp16 ".format(
-                batch, batch)
+            params = "--input-dim @sample {} 4 64 64 @timestep 1 @encoder_hidden_states {} 64 1024 --fp16 {}".format(
+                batch, batch, prec_str)
             if "bert" in model:
-                params = "--fill1 input_ids --input-dim @input_ids {} {} ".format(
-                    batch, seq_len)
+                params = "{} --fp16 --fill1 input_ids --input-dim @input_ids {} {} ".format(
+                    prec_str, batch, seq_len)
+            if "squad" in model:
+                params = "--fill1 input_ids:0 unique_ids_raw_output___9:0 input_mask:0 segment_ids:0 --input-dim @input_ids:0 {} 256 @input_mask:0 {} 256 @segment_ids:0 {} 256 --fp16 {}".format(
+                    batch, batch, batch, prec_str)
             out = subprocess.run(
                 'MIGRAPHX_LOG_CK_GEMM=1 ../build/bin/driver run {} -g {} | grep \'ck_gemm.*: \[{{\' | sort -u >> {}'
                 .format(model, params, log_file),
@@ -96,7 +101,7 @@ def tune_models(models, batch_sizes, seq_len, n, existing):
 def run(args):
     tune_models(args.models, args.batch_sizes, args.sequence_length, args.n,
-                args.update)
+                args.update, args.quantize_int8)


 run(parse_args())
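With the new -q/--quantize_int8 flag, tools/tune_models.py appends --int8 to the driver arguments (via prec_str), so the ck_gemm configurations collected through MIGRAPHX_LOG_CK_GEMM=1 come from int8-quantized runs of the listed models.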