Commit 123c69b7 authored by Alan Turner
Browse files

Enable stable diffusion unet

parent bee5f9b5
......@@ -878,7 +878,8 @@ void fuse_ops::apply(module& m) const
find_concat_pointwise{},
find_gemm_pointwise{},
find_contiguous_tranpose_gemm{},
find_contiguous_tranpose_precompile{},
// Commented out as workaround for reshape error when running Unet
// find_contiguous_tranpose_precompile{},
find_commutative_broadcast{});
match::find_matches(m, find_contiguous{});
}
......
......@@ -155,13 +155,9 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
// Returns the name of the CK GEMM layout tag ("...::RowMajor" or
// "...::ColumnMajor") to use for the given shape.
//
// NOTE(review): this listing looks like a diff rendered without +/- markers
// (the hunk header reads "-155,13 +155,9"). Presumably the
// `transposed()`/`lens()` statements are the removed implementation and the
// final strides-based return is its replacement -- confirm against the
// repository before treating both as live code.
static std::string get_layout(const shape& s)
{
// Presumably the pre-change rule: a shape that is not transposed is RowMajor.
if(not s.transposed())
return "ck::tensor_layout::gemm::RowMajor";
// For a transposed shape, ColumnMajor is chosen only when the last dimension
// is longer than the second-to-last one; otherwise RowMajor.
auto lens = s.lens();
return lens[lens.size() - 1] > lens[lens.size() - 2]
? "ck::tensor_layout::gemm::ColumnMajor"
: "ck::tensor_layout::gemm::RowMajor";
// Presumably the post-change rule: RowMajor when the last stride is 1,
// ColumnMajor otherwise. As literally written here this statement is
// unreachable, because the return above is unconditional.
return s.strides().back() == 1
? "ck::tensor_layout::gemm::RowMajor"
: "ck::tensor_layout::gemm::ColumnMajor";
}
static std::string get_type(const shape& s)
......
......@@ -182,7 +182,7 @@ struct index
}
else
{
static_assert(max_stride_iterations(n, stride) < 64);
static_assert(max_stride_iterations(n, stride) < 128);
sequence(max_stride_iterations(n, stride), [&](auto... ks) {
fold([&](auto d, auto k) {
auto i = start + stride * k;
......
......@@ -40,9 +40,12 @@ def tune_models(models, batch_sizes, seq_len, n, existing):
json_file = "ck_tuning_{}.json".format(time_stamp)
for model in models:
for batch in batch_sizes:
params = "--input-dim @sample {} 4 64 64 @timestep 1 @encoder_hidden_states {} 64 1024 --fp16 ".format(batch, batch)
if "bert" in model:
params = "--fill1 input_ids --input-dim @input_ids {} {} ".format(batch, seq_len)
out = subprocess.run(
'MIGRAPHX_LOG_CK_GEMM=1 ../build/bin/driver run {} -g --fill1 input_ids --input-dim @input_ids {} {} | grep \'ck_gemm.*: \[{{\' | sort -u >> {}'
.format(model, batch, seq_len, log_file),
'MIGRAPHX_LOG_CK_GEMM=1 ../build/bin/driver run {} -g {} | grep \'ck_gemm.*: \[{{\' | sort -u >> {}'
.format(model, params, log_file),
capture_output=True,
check=True,
shell=True)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment