"src/include/blockwise_2d_tensor_op.hpp" did not exist on "b2888adfbe103ae3d9006af87d5871b69cbf00ba"
Commit 1c1da090 authored by rtmadduri
Browse files

fixing sequence length

parent a6f99d8f
......@@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSpl
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 3, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 3, 8, 8, 1, 1, 1, S<32, 1, 8>, 8>;
// clang-format on
#include "run_grouped_gemm_example.inc"
......
......@@ -200,9 +200,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
ComputeTypeA,
ComputeTypeB>;
// Block2CTileMap configuration parameter.
using Block2ETileMap = typename GridwiseGemm::Block2CTileMap;
;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
using KernelArgument = typename GridwiseGemm::Argument;
......@@ -279,14 +277,15 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_c = gemm_descs[i].stride_C_;
const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
// const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
// const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
const auto c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(M, m_padded, N, n_padded, stride_c);
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(M, N, K_BATCH);
const auto local_b2c_tile_map = Block2ETileMap{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const auto local_b2c_tile_map = Block2ETileMap{gdx, gdy, gdz};
// const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const index_t grid_size_grp = gdx * gdy * gdz;
const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp;
......@@ -328,21 +327,21 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
auto& karg = gemm_kernel_args_[i].karg_;
const index_t m_padded = GridwiseGemm::CalculateMPadded(karg.M);
const index_t n_padded = GridwiseGemm::CalculateNPadded(karg.N);
// const index_t m_padded = GridwiseGemm::CalculateMPadded(karg.M);
// const index_t n_padded = GridwiseGemm::CalculateNPadded(karg.N);
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
karg.M, karg.N, m_padded, n_padded, karg.StrideC);
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N, karg.KBatch);
const auto local_b2c_tile_map = Block2ETileMap{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const auto local_b2c_tile_map = Block2ETileMap{gdx, gdy, gdz};
// const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const index_t grid_size_grp = gdx * gdy * gdz;
const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
// block-to-e-tile map
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
......@@ -440,9 +439,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_kernel_args_.size(),
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation);
PassThrough{},
PassThrough{},
PassThrough{});
};
constexpr index_t minimum_occupancy =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment