"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "138fac703a8359460a59d20082114195759722b8"
Commit 62fdce6d authored by ltqin

remove a/b matrix k0mk1 desc in grid
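Summary (reader's note, inferred from the diff below): this change drops the standalone `AK0MK1GridDesc`/`BK0NK1GridDesc` descriptors from the v2r4 xdlops GEMM kernels, from `GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4::Run`, and from the driver; only the batched `ABK0MK1GridDesc`/`BBK0NK1GridDesc` descriptors are still plumbed through (their leading dimension appears to be the k-batch, since `K0` is read from `GetLength(I1)` and `M`/`N` from `GetLength(I2)`). It also bumps `KBatch` from 2 to 64 in the backward-weight convolution wrapper and raises a debug printf's block-id threshold.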

parent b7c1259f
@@ -16,8 +16,6 @@ namespace ck {
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
-          typename AK0MK1GridDesc,
-          typename BK0NK1GridDesc,
           typename ABK0MK1GridDesc,
           typename BBK0NK1GridDesc,
           typename CM0N0M1N1M2M3M4N2GridDesc,
@@ -29,10 +27,8 @@ __global__ void
 kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
                         const FloatAB* __restrict__ p_b_grid,
                         FloatC* __restrict__ p_c_grid,
-                        const AK0MK1GridDesc a_k0_m_k1_grid_desc,
-                        const BK0NK1GridDesc b_k0_n_k1_grid_desc,
-                        const void CONSTANT* a_b_k0_m_k1_grid_desc,
-                        const void CONSTANT* b_b_k0_n_k1_grid_desc,
+                        const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc,
+                        const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc,
                         const CM0N0M1N1M2M3M4N2GridDesc c_m0_m1_m2_n_grid_desc,
                         const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
@@ -45,8 +41,6 @@ __global__ void
                                   p_b_grid,
                                   p_c_grid,
                                   p_shared_block,
-                                  a_k0_m_k1_grid_desc,
-                                  b_k0_n_k1_grid_desc,
                                   a_b_k0_m_k1_grid_desc,
                                   b_b_k0_n_k1_grid_desc,
                                   c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
@@ -56,8 +50,6 @@ __global__ void
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
-          typename AK0MK1GridDesc,
-          typename BK0NK1GridDesc,
           typename ABK0MK1GridDesc,
           typename BBK0NK1GridDesc,
           typename CM0N0M1N1M2M3M4N2GridDesc,
@@ -69,8 +61,6 @@ __global__ void
 kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid,
                         const FloatAB* __restrict__ p_b_grid,
                         FloatC* __restrict__ p_c_grid,
-                        const void CONSTANT* p_a_k0_m_k1_grid_desc,
-                        const void CONSTANT* p_b_k0_n_k1_grid_desc,
                         const void CONSTANT* p_a_b_k0_m_k1_grid_desc,
                         const void CONSTANT* p_b_b_k0_n_k1_grid_desc,
                         const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
@@ -79,10 +69,6 @@ __global__ void
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
 
-    const auto a_k0_m_k1_grid_desc = *reinterpret_cast<const AK0MK1GridDesc*>(
-        cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
-    const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
-        cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc));
     const auto a_b_k0_m_k1_grid_desc = *reinterpret_cast<const ABK0MK1GridDesc*>(
         cast_pointer_to_generic_address_space(p_a_b_k0_m_k1_grid_desc));
     const auto b_b_k0_n_k1_grid_desc = *reinterpret_cast<const BBK0NK1GridDesc*>(
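Reader's note: the `CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER` path above type-erases each descriptor to a `const void CONSTANT*` on the host and `reinterpret_cast`s it back to the concrete descriptor type inside the kernel. A minimal sketch of the same type-erase/recover round trip in plain host C++, without CK's address-space machinery; `DescK0MK1` and `read_k0` are illustrative names, not CK's API:

```cpp
#include <cassert>

// Stand-in for a compile-time tensor descriptor type; illustrative only.
struct DescK0MK1
{
    int k0, m, k1;
};

// The callee recovers the concrete descriptor from the opaque pointer,
// mirroring *reinterpret_cast<const ABK0MK1GridDesc*>(...) above
// (minus cast_pointer_to_generic_address_space, which has no host analogue).
template <typename Desc>
int read_k0(const void* p_desc)
{
    const auto desc = *reinterpret_cast<const Desc*>(p_desc);
    return desc.k0;
}

int main()
{
    const DescK0MK1 desc{8, 256, 4};
    assert(read_k0<DescK0MK1>(&desc) == 8); // round trip recovers the value
    return 0;
}
```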
@@ -99,8 +85,6 @@ __global__ void
                           p_b_grid,
                           p_c_grid,
                           p_shared_block,
-                          a_k0_m_k1_grid_desc,
-                          b_k0_n_k1_grid_desc,
                           a_b_k0_m_k1_grid_desc,
                           b_b_k0_n_k1_grid_desc,
                           c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
@@ -348,8 +332,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
                                 const FloatAB* __restrict__ p_b_grid,
                                 FloatC* __restrict__ p_c_grid,
                                 FloatAB* __restrict__ p_shared_block,
-                                const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
-                                const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
                                 const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
                                 const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
                                 const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
@@ -362,13 +344,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
 
         const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1);
         const auto M  = a_b_k0_m_k1_grid_desc.GetLength(I2);
         const auto N  = b_b_k0_n_k1_grid_desc.GetLength(I2);
 
         const auto b_grid_size       = CalculateGridSize(M, N);
         const auto k_batch_id        = get_block_1d_id() / b_grid_size;
         const auto block_id_in_batch = get_block_1d_id() % b_grid_size;
-        if(get_block_1d_id() == 2000)
+        if(get_block_1d_id() == 200000)
             printf("grid size: %d, k0: %d, blockid: %d, threadid %d, Batch: %d block_id: %d \n",
                    b_grid_size,
                    K0,
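Reader's note: the hunk above is the core split-K indexing: the 1-D grid is laid out as KBatch copies of the per-tile grid placed end to end, so dividing/modding a flat block id by the per-batch grid size recovers both the k-batch and the block within it. A self-contained sketch of that decomposition; the constants stand in for the runtime `CalculateGridSize(M, N)` result and `KBatch`:

```cpp
#include <cstdio>

int main()
{
    const int b_grid_size = 4; // stand-in for CalculateGridSize(M, N)
    const int k_batch     = 2; // stand-in for KBatch

    // Every flat block id maps to exactly one (k_batch_id, block_id_in_batch).
    for(int block_1d_id = 0; block_1d_id < k_batch * b_grid_size; ++block_1d_id)
    {
        const int k_batch_id        = block_1d_id / b_grid_size;
        const int block_id_in_batch = block_1d_id % b_grid_size;
        std::printf("block %d -> k-batch %d, block-in-batch %d\n",
                    block_1d_id, k_batch_id, block_id_in_batch);
    }
    return 0;
}
```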
@@ -426,10 +408,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
                 1,
                 1,
                 AThreadTransferSrcResetCoordinateAfterRun,
-                true>(a_b_k0_m_k1_grid_desc,
-                      make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
-                      a_b_k0_m_k1_block_desc,
-                      make_multi_index(0, 0, 0, 0));
+                true>(
+                a_b_k0_m_k1_grid_desc,
+                make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
+                a_b_k0_m_k1_block_desc,
+                make_multi_index(0, 0, 0, 0));
 
         // B matrix blockwise copy
         auto b_blockwise_copy =
@@ -452,10 +435,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
                 1,
                 1,
                 BThreadTransferSrcResetCoordinateAfterRun,
-                true>(b_b_k0_n_k1_grid_desc,
-                      make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
-                      b_b_k0_n_k1_block_desc,
-                      make_multi_index(0, 0, 0, 0));
+                true>(
+                b_b_k0_n_k1_grid_desc,
+                make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
+                b_b_k0_n_k1_block_desc,
+                make_multi_index(0, 0, 0, 0));
 
         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
...
@@ -76,7 +76,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
 
-    constexpr index_t KBatch = 2;
+    constexpr index_t KBatch = 64;
 #elif 1
     // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
     constexpr index_t BlockSize = 256;
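Reader's note: `KBatch` is the number of slices the K (reduction) dimension is split into; each k-batch reduces a disjoint K-slice, and the partial results are combined afterwards, so raising it from 2 to 64 launches 64 per-tile grids instead of 2. A minimal sketch of why the split is legal, using a dot product as the stand-in for the GEMM reduction (`K`, `KBatch` values here are illustrative; `KBatch` is assumed to divide `K`):

```cpp
#include <cassert>
#include <vector>

int main()
{
    const int K = 128, KBatch = 64;
    const int KPerBatch = K / KBatch;

    std::vector<int> a(K, 2), b(K, 3);

    // Full reduction over K.
    int full = 0;
    for(int k = 0; k < K; ++k)
        full += a[k] * b[k];

    // Split-K: each batch reduces its own disjoint K-slice.
    int split = 0;
    for(int batch = 0; batch < KBatch; ++batch)
        for(int k = batch * KPerBatch; k < (batch + 1) * KPerBatch; ++k)
            split += a[k] * b[k];

    assert(split == full); // partial sums reproduce the full reduction
    return 0;
}
```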
@@ -132,15 +132,15 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
                               Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: GemmM
                               Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
 
-    constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmB
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmK0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmB
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmK0
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
+    constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
+        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: GemmB
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: GemmK0
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: GemmN
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 3+: GemmK1
+        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: GemmB
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: GemmK0
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: GemmN
+                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
 
     constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
         make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
...
@@ -156,8 +156,6 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
     const auto kernel = kernel_gemm_xdlops_v2r4<GridwiseGemm,
                                                 FloatAB,
                                                 FloatC,
-                                                remove_reference_t<AK0MK1GridDesc>,
-                                                remove_reference_t<BK0NK1GridDesc>,
                                                 remove_reference_t<ABK0MK1GridDesc>,
                                                 remove_reference_t<BBK0NK1GridDesc>,
                                                 remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
@@ -172,23 +170,17 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
                                           p_a_grid,
                                           p_b_grid,
                                           p_c_grid,
-                                          a_k0_m_k1_grid_desc,
-                                          b_k0_n_k1_grid_desc,
                                           a_b_k0_m_k1_grid_desc,
                                           b_b_k0_n_k1_grid_desc,
                                           c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
                                           c_block_cluster_adaptor);
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
-    DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
-    DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
     DeviceMem a_b_k0_m_k1_grid_desc_dev_buf(sizeof(ABK0MK1GridDesc));
     DeviceMem b_b_k0_n_k1_grid_desc_dev_buf(sizeof(BBK0NK1GridDesc));
     DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
     DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
 
-    a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
-    b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
     a_b_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_b_k0_m_k1_grid_desc);
     b_b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_b_k0_n_k1_grid_desc);
     c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
@@ -203,8 +195,6 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
         p_a_grid,
         p_b_grid,
         p_c_grid,
-        cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
-        cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
         cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
         cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
         cast_pointer_to_constant_address_space(
...