"scripts/convert_original_audioldm_to_diffusers.py" did not exist on "039958eae55ff0700cfb42a7e72739575ab341f1"
Commit de330330 authored by Adam Osewski's avatar Adam Osewski
Browse files

Device gemm splitk use B2C map.

parent 5f7eda63
......@@ -114,7 +114,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
CBlockTransferScalarPerVector_NWaveNPerXDL,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
using Argument = typename GridwiseGemm::Argument;
using Argument = typename GridwiseGemm::Argument;
using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
// Invoker
struct Invoker : public BaseInvoker
......@@ -138,8 +139,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
"setting");
}
const auto b2c_map = DefaultBlock2CTileMap{};
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg);
std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);
const auto K0 = karg.K0;
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
......@@ -152,7 +154,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map);
};
if(has_main_k0_block_loop)
......@@ -162,7 +164,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const auto kernel =
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set>;
InMemoryDataOperationEnum::Set,
DefaultBlock2CTileMap>;
Run(kernel);
}
......@@ -171,7 +174,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const auto kernel =
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd>;
InMemoryDataOperationEnum::AtomicAdd,
DefaultBlock2CTileMap>;
Run(kernel);
}
......@@ -183,7 +187,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const auto kernel =
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set>;
InMemoryDataOperationEnum::Set,
DefaultBlock2CTileMap>;
Run(kernel);
}
......@@ -192,7 +197,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const auto kernel =
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd>;
InMemoryDataOperationEnum::AtomicAdd,
DefaultBlock2CTileMap>;
Run(kernel);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment