Commit 123c69b7 authored by Alan Turner
Browse files

Enable stable diffusion unet

parent bee5f9b5
...@@ -878,7 +878,8 @@ void fuse_ops::apply(module& m) const ...@@ -878,7 +878,8 @@ void fuse_ops::apply(module& m) const
find_concat_pointwise{}, find_concat_pointwise{},
find_gemm_pointwise{}, find_gemm_pointwise{},
find_contiguous_tranpose_gemm{}, find_contiguous_tranpose_gemm{},
find_contiguous_tranpose_precompile{}, // Commented out as workaround for reshape error when running Unet
// find_contiguous_tranpose_precompile{},
find_commutative_broadcast{}); find_commutative_broadcast{});
match::find_matches(m, find_contiguous{}); match::find_matches(m, find_contiguous{});
} }
......
...@@ -155,13 +155,9 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -155,13 +155,9 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{ {
static std::string get_layout(const shape& s) static std::string get_layout(const shape& s)
{ {
if(not s.transposed()) return s.strides().back() == 1
return "ck::tensor_layout::gemm::RowMajor"; ? "ck::tensor_layout::gemm::RowMajor"
: "ck::tensor_layout::gemm::ColumnMajor";
auto lens = s.lens();
return lens[lens.size() - 1] > lens[lens.size() - 2]
? "ck::tensor_layout::gemm::ColumnMajor"
: "ck::tensor_layout::gemm::RowMajor";
} }
static std::string get_type(const shape& s) static std::string get_type(const shape& s)
......
...@@ -182,7 +182,7 @@ struct index ...@@ -182,7 +182,7 @@ struct index
} }
else else
{ {
static_assert(max_stride_iterations(n, stride) < 64); static_assert(max_stride_iterations(n, stride) < 128);
sequence(max_stride_iterations(n, stride), [&](auto... ks) { sequence(max_stride_iterations(n, stride), [&](auto... ks) {
fold([&](auto d, auto k) { fold([&](auto d, auto k) {
auto i = start + stride * k; auto i = start + stride * k;
......
...@@ -40,9 +40,12 @@ def tune_models(models, batch_sizes, seq_len, n, existing): ...@@ -40,9 +40,12 @@ def tune_models(models, batch_sizes, seq_len, n, existing):
json_file = "ck_tuning_{}.json".format(time_stamp) json_file = "ck_tuning_{}.json".format(time_stamp)
for model in models: for model in models:
for batch in batch_sizes: for batch in batch_sizes:
params = "--input-dim @sample {} 4 64 64 @timestep 1 @encoder_hidden_states {} 64 1024 --fp16 ".format(batch, batch)
if "bert" in model:
params = "--fill1 input_ids --input-dim @input_ids {} {} ".format(batch, seq_len)
out = subprocess.run( out = subprocess.run(
'MIGRAPHX_LOG_CK_GEMM=1 ../build/bin/driver run {} -g --fill1 input_ids --input-dim @input_ids {} {} | grep \'ck_gemm.*: \[{{\' | sort -u >> {}' 'MIGRAPHX_LOG_CK_GEMM=1 ../build/bin/driver run {} -g {} | grep \'ck_gemm.*: \[{{\' | sort -u >> {}'
.format(model, batch, seq_len, log_file), .format(model, params, log_file),
capture_output=True, capture_output=True,
check=True, check=True,
shell=True) shell=True)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment