"scripts/convert_original_controlnet_to_diffusers.py" did not exist on "31444f5790a8f91fc2caf3995fadc561e23fb279"
Commit 3002a1a5 authored by Adam Osewski's avatar Adam Osewski
Browse files

Use custom b2c in wavelet and default in regular gemm.

parent dc0bae32
......@@ -359,7 +359,12 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
#if 0
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
#else
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_);
#endif
const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
auto launch_kernel = [&](auto has_main_k_block_loop) {
......
......@@ -484,8 +484,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
#if 1
const index_t grid_size =
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
#else
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
#endif
// const index_t grid_size =
// arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
......
......@@ -204,7 +204,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N& c_grid_desc_m_n,
const Block2CTileMap& block_2_ctile_map)
const Block2CTileMap& /*block_2_ctile_map*/)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
......@@ -228,10 +228,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return false;
}
#if 0
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
{
return false;
}
#endif
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
......@@ -263,6 +265,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return c_grid_desc_mblock_mperblock_nblock_nperblock;
}
#if 1
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
......@@ -270,6 +273,55 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
}
#else
// return block_id to E matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M01 = I1;
constexpr auto N01 = I1;
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(M0, M01)),
make_unmerge_transform(make_tuple(N0, N01))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));
const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, N0, M01, N01))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto cblockid_to_m0_n0_block_cluster_adaptor =
chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
return cblockid_to_m0_n0_block_cluster_adaptor;
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
#endif
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
......@@ -302,6 +354,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
#if 1
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
......@@ -309,7 +362,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
return;
}
#endif
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
......
......@@ -208,6 +208,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
return false;
}
#if 0 // debug
// check block-to-E-tile
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
#endif
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// check tensor size: cannot be larger than 2GB each
......@@ -230,6 +238,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
return GridwiseGemmMath::CalculateHasMainLoop(num_loop);
}
#if 0
// return block_id to E matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
}
#else
// return block_id to E matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
......@@ -276,6 +294,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
return grid_size;
}
#endif
// A desc for source in blockwise copy
__host__ __device__ static constexpr auto
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment