"scripts/wan/run_wan_t2v_causvid.sh" did not exist on "7c3da5c00225fda02c4a391afde62d6c006da602"
Commit 59a53257 authored by Manupa Karunaratne's avatar Manupa Karunaratne
Browse files

* clang-format

parent 83b9164b
......@@ -206,8 +206,8 @@ auto is_mlir_conv(mlir_mode mode)
}
std::unordered_map<instruction_ref, instruction_ref>
create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape)
{
create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape)
{
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(*pm))
{
......@@ -217,20 +217,24 @@ std::unordered_map<instruction_ref, instruction_ref>
}
literal r = ins->get_literal();
instruction_ref literal = mm->add_literal(r);
instruction_ref mbcast = mm->add_instruction(
make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
instruction_ref mbcast =
mm->add_instruction(make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
ins_map[ins] = mbcast;
}
return ins_map;
}
}
std::vector<instruction_ref> fold_pointwise_mod(instruction_ref pm_ins, module_ref parent_mod, const std::unordered_map<instruction_ref, instruction_ref>& ins_map){
std::vector<instruction_ref>
fold_pointwise_mod(instruction_ref pm_ins,
module_ref parent_mod,
const std::unordered_map<instruction_ref, instruction_ref>& ins_map)
{
auto* pm = pm_ins->module_inputs().front();
auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end());
std::unordered_map<instruction_ref, instruction_ref> param_map =
create_param_map_with_literals(parent_mod, pm, pm_ins->get_shape());
std::transform( names.begin(),
std::transform(names.begin(),
names.end(),
pm_ins->inputs().begin(),
std::inserter(param_map, param_map.end()),
......@@ -310,8 +314,6 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
return false;
}
struct find_mlir_fused_ops
{
mlir_mode conv_mode = mlir_mode::none;
......@@ -415,7 +417,8 @@ struct find_mlir_standalone_attention_op
auto [new_top_ins, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, top_ins);
inputs.insert(inputs.begin(), top_inputs.begin(), top_inputs.end());
ins_map[top_ins] = new_top_ins;
if(r.instructions.find("scale") != r.instructions.end()){
if(r.instructions.find("scale") != r.instructions.end())
{
auto scale_ins = r.instructions["scale"];
new_top_ins = fold_pointwise_mod(scale_ins, mm, ins_map)[0];
std::copy_if(scale_ins->inputs().begin(),
......@@ -428,7 +431,8 @@ struct find_mlir_standalone_attention_op
r.instructions["gemm1"]->inputs().end(),
std::inserter(ins_map, ins_map.end()),
[&](auto old_ins) {
if(old_ins == r.instructions["softmax"]){
if(old_ins == r.instructions["softmax"])
{
return std::make_pair(old_ins, softmax);
}
inputs.push_back(old_ins);
......@@ -441,7 +445,8 @@ struct find_mlir_standalone_attention_op
ins_map[r.instructions["gemm1"]] = new_gemm1;
auto ins_to_replace = new_gemm1;
auto ins_to_be_replaced = r.instructions["gemm1"];
if(r.instructions.find("trailing_pm") != r.instructions.end()){
if(r.instructions.find("trailing_pm") != r.instructions.end())
{
ins_to_replace = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0];
std::copy_if(r.instructions["trailing_pm"]->inputs().begin(),
r.instructions["trailing_pm"]->inputs().end(),
......@@ -454,28 +459,42 @@ struct find_mlir_standalone_attention_op
ins_to_be_replaced, mlir_op{new_gemm1->get_operator()}, inputs, {mm});
}
auto matcher() const {
auto match_softmax_input = match::any_of[match::inputs()](match::name("dot").bind("gemm0"), match::name("pointwise")(match::any_of[match::inputs()](match::name("dot").bind("gemm0"))).bind("scale"));
auto is_mlir_attention = match::name("dot")(match::any_of[match::inputs()](match::name("softmax")(match_softmax_input).bind("softmax"))).bind("gemm1");
auto matcher() const
{
auto match_softmax_input = match::any_of[match::inputs()](
match::name("dot").bind("gemm0"),
match::name("pointwise")(
match::any_of[match::inputs()](match::name("dot").bind("gemm0")))
.bind("scale"));
auto is_mlir_attention =
match::name("dot")(match::any_of[match::inputs()](
match::name("softmax")(match_softmax_input).bind("softmax")))
.bind("gemm1");
return is_mlir_attention;
}
bool check(const match::matcher_result& r) const {
bool check(const match::matcher_result& r) const
{
// We are only enabling attention
// in the highest enablement mode for now
if(mode != mlir_mode::all){
if(mode != mlir_mode::all)
{
return false;
}
auto gemm0 = r.instructions["gemm0"];
// Check the pointwise mod only contains a single mul
if(r.instructions.find("scale") != r.instructions.end()){
if(r.instructions.find("scale") != r.instructions.end())
{
auto scale_pm = r.instructions["scale"];
bool found_mul = false;
for(const auto& scale_ins : *scale_pm->module_inputs().front()){
if(contains({"@param", "@literal", "@return"}, scale_ins.name())){
for(const auto& scale_ins : *scale_pm->module_inputs().front())
{
if(contains({"@param", "@literal", "@return"}, scale_ins.name()))
{
continue;
}
if(scale_ins.name() == "mul" && !found_mul){
if(scale_ins.name() == "mul" && !found_mul)
{
found_mul = true;
continue;
}
......@@ -487,7 +506,8 @@ struct find_mlir_standalone_attention_op
return not contains(
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
i->get_shape().type());
})){
}))
{
return false;
}
return true;
......@@ -495,7 +515,8 @@ struct find_mlir_standalone_attention_op
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
if(!check(r)){
if(!check(r))
{
return;
}
rewrite(mpm, r);
......@@ -504,13 +525,18 @@ struct find_mlir_standalone_attention_op
struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
{
auto matcher() const {
auto matcher() const
{
auto standalone_matcher = find_mlir_standalone_attention_op::matcher();
return match::name("pointwise")(match::any_of[match::inputs()](standalone_matcher).bind("trailing_pm"));;
return match::name("pointwise")(
match::any_of[match::inputs()](standalone_matcher).bind("trailing_pm"));
;
}
bool check(const match::matcher_result& r) const {
if(!find_mlir_standalone_attention_op::check(r)){
bool check(const match::matcher_result& r) const
{
if(!find_mlir_standalone_attention_op::check(r))
{
return false;
}
auto trailing_pm_ins = r.instructions["trailing_pm"]; // input after contiguous
......@@ -525,7 +551,8 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
if(!check(r)){
if(!check(r))
{
return;
}
rewrite(mpm, r);
......@@ -575,8 +602,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
(enabled(MIGRAPHX_ENABLE_EXTRA_MLIR{}) or enable_extra) ? mlir_mode::fast : mlir_mode::none;
// Attention offloads; default disabled
match::find_matches(mpm,
find_mlir_attention_fused_ops{get_mode("attention", mlir_mode::none)});
match::find_matches(mpm, find_mlir_attention_fused_ops{get_mode("attention", mlir_mode::none)});
match::find_matches(mpm,
find_mlir_standalone_attention_op{get_mode("attention", mlir_mode::none)});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment