Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5cef60b8
Commit
5cef60b8
authored
May 25, 2023
by
Alan Turner
Browse files
Cleanup
parent
2f2757ac
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
90 deletions
+35
-90
src/quantization.cpp
src/quantization.cpp
+21
-21
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+2
-45
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+12
-24
No files found.
src/quantization.cpp
View file @
5cef60b8
...
...
@@ -112,28 +112,28 @@ void quantize_int8(program& prog,
max_abs_vals
->
resize
(
param_num
,
0.0
f
);
// use the calibration data to compute the quantization scale
//
auto capture_prog = prog;
//
capture_prog.compile(t);
auto
capture_prog
=
prog
;
capture_prog
.
compile
(
t
);
//
//
use all calibration data to run the program to calculate the
//
//
quantization scale and shift
//
for(auto&& arg : calibration)
//
{
//
parameter_map m;
//
for(auto&& x : capture_prog.get_parameter_shapes())
//
{
//
if(arg.count(x.first) > 0)
//
{
//
assert(x.second == arg.at(x.first).get_shape());
//
m[x.first] = t.copy_to(arg.at(x.first));
//
}
//
else
//
{
//
m[x.first] = t.allocate(x.second);
//
}
//
}
//
capture_prog.eval(m);
//
}
// use all calibration data to run the program to calculate the
// quantization scale and shift
for
(
auto
&&
arg
:
calibration
)
{
parameter_map
m
;
for
(
auto
&&
x
:
capture_prog
.
get_parameter_shapes
())
{
if
(
arg
.
count
(
x
.
first
)
>
0
)
{
assert
(
x
.
second
==
arg
.
at
(
x
.
first
).
get_shape
());
m
[
x
.
first
]
=
t
.
copy_to
(
arg
.
at
(
x
.
first
));
}
else
{
m
[
x
.
first
]
=
t
.
allocate
(
x
.
second
);
}
}
capture_prog
.
eval
(
m
);
}
// print the quantization parameters in only the main module
if
(
enabled
(
MIGRAPHX_INT8_QUANTIZATION_PARAMS
{}))
...
...
src/targets/gpu/fuse_ck.cpp
View file @
5cef60b8
...
...
@@ -96,8 +96,8 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return
false
;
auto
a
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
b
=
ins
->
inputs
().
back
()
->
get_shape
();
//
if(a.lens().back() > 2048)
//
return false;
if
(
a
.
lens
().
back
()
>
2048
)
return
false
;
return
true
;
}
...
...
@@ -160,20 +160,6 @@ struct find_ck_gemm_pointwise_int8
auto
gemm_ins
=
r
.
instructions
[
"gemm"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
// input after contiguous
auto
next_ins
=
std
::
next
(
ins
);
// if (next_ins->name() == "quant_dot")
// {
// std::cout << "\nins: ";
// ins->debug_print();
// std::cout << "\ngemm_ins: ";
// gemm_ins->debug_print();
// std::cout << "\nx_ins: ";
// x_ins->debug_print();
// std::cout << "\nnext: ";
// next_ins->debug_print();
// mpm.get_module().debug_print();
// }
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
names
=
pm
->
get_parameter_names
();
...
...
@@ -182,13 +168,6 @@ struct find_ck_gemm_pointwise_int8
auto
gemm_it
=
std
::
find
(
inputs
.
begin
(),
inputs
.
end
(),
x_ins
);
auto
gemm_idx
=
gemm_it
-
inputs
.
begin
();
assert
(
gemm_it
!=
inputs
.
end
());
// if(ins->get_shape().type() != shape::half_type)
// return;
// if (next_ins->name() == "reshape")
// {
// std::cout << "PM before: " << std::endl;
// pm->debug_print();
// }
if
(
gemm_idx
!=
0
)
{
auto
first_param
=
pm
->
get_parameter
(
names
[
0
]);
...
...
@@ -201,31 +180,9 @@ struct find_ck_gemm_pointwise_int8
pm
->
remove_instruction
(
first_param
);
pm
->
remove_instruction
(
gemm_param
);
}
// if (next_ins->name() == "reshape")
// {
// std::cout << "PM after: " << std::endl;
// pm->debug_print();
// }
inputs
.
erase
(
gemm_it
);
inputs
.
insert
(
inputs
.
begin
(),
gemm_ins
->
inputs
().
begin
(),
gemm_ins
->
inputs
().
end
());
// std::cout << "Next_ins inputs: " << std::endl;
// for (auto& in : next_ins->inputs())
// {
// in->debug_print();
// }
// auto out_shape = compute_shape(ck_gemm_int8{}, inputs, {pm});
// instruction::replace(ins, ck_gemm_int8{}, out_shape.with_type(migraphx::shape::half_type), inputs, {pm});
mpm
.
get_module
().
replace_instruction
(
ins
,
ck_gemm_int8
{},
inputs
,
{
pm
});
// std::cout << "Next_ins inputs (post replace): " << std::endl;
// for (auto& in : std::next(ins)->inputs())
// {
// in->debug_print();
// }
// if (next_ins->name() == "softmax" or next_ins->name() == "reshape")
// {
// std::cout << "After replace: " << std::endl;
// mpm.get_module().debug_print();
// }
}
};
...
...
src/targets/gpu/jit/ck_gemm.cpp
View file @
5cef60b8
...
...
@@ -326,22 +326,6 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
std
::
array
<
std
::
size_t
,
3
>
config
{
m
,
n
,
k
};
auto
tuning_val
=
v
.
get
(
"tuning_val"
,
get_tuning_for
({
a_shape
,
b_shape
,
c_shape
.
with_type
(
a_shape
.
type
())}));
auto
ip
=
instance
{
get_instance
(
tuning_val
,
[
&
](
const
auto
&
x
)
->
bool
{
// if (not (get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
// get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
// get_type(b_shape) == x[5] and get_type(c_shape) == x[9]))
// {
// std::cout << get_layout(a_shape) << " - " << x[0] <<std::endl;
// std::cout << get_layout(b_shape) << " - " << x[1] <<std::endl;
// std::cout << get_layout(c_shape) << " - " << x[3] <<std::endl;
// std::cout << get_type(a_shape) << " - " << x[4] <<std::endl;
// std::cout << get_type(b_shape) << " - " << x[5] <<std::endl;
// std::cout << get_type(c_shape) << " - " << x[9] <<std::endl;
// }
/* return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5] and get_type(c_shape) == x[9]; */
return
get_layout
(
a_shape
)
==
x
[
0
]
and
get_layout
(
b_shape
)
==
x
[
1
]
and
get_layout
(
c_shape
)
==
x
[
3
]
and
get_type
(
a_shape
)
==
x
[
4
]
and
get_type
(
b_shape
)
==
x
[
5
];
...
...
@@ -354,6 +338,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
ip
.
set_ds_op
(
v
.
at
(
"post"
).
to
<
std
::
string
>
());
}
if
(
a_shape
.
type
()
==
shape
::
int8_type
)
{
ip
.
set_e_type
(
get_type
(
c_shape
));
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
auto
s
)
{
return
get_type
(
s
)
==
"ck::half_t"
;
}))
{
...
...
@@ -363,6 +349,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
ip
.
set_c_scalar_per_vec
(
"4"
);
}
}
auto
padding
=
ip
.
get_pad
(
config
);
...
...
@@ -407,7 +395,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
"blocks_per_batch"
,
to_string
(
blocks_per_batch
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"kernel"
,
options
.
kernel_name
}});
// std::cout << options.kernel_name << ": " << std::endl;
return
compile_hip_code_object
(
src
,
options
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment