Unverified Commit 5aa3c344 authored by rocking5566's avatar rocking5566 Committed by GitHub
Browse files

Merge branch 'develop' into gemm_layernorm_welford

parents 7fefc966 9d8f834a
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#pragma once #pragma once
#include <cmath>
#include <string> #include <string>
#include "ck/stream_config.hpp" #include "ck/stream_config.hpp"
......
...@@ -506,12 +506,12 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -506,12 +506,12 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(batch_stride_A_); return static_cast<long_index_t>(g_idx) * batch_stride_A_;
} }
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(batch_stride_B_); return static_cast<long_index_t>(g_idx) * batch_stride_B_;
} }
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
...@@ -519,8 +519,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -519,8 +519,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
std::array<long_index_t, NumDTensor> ds_offset; std::array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_offset[i] = ds_offset[i] = static_cast<long_index_t>(g_idx) *
ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(g_idx, 0, 0)); ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0));
}); });
return ds_offset; return ds_offset;
...@@ -528,7 +528,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -528,7 +528,8 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{ {
return e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return static_cast<long_index_t>(g_idx) *
e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0));
} }
private: private:
...@@ -549,10 +550,6 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -549,10 +550,6 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -586,12 +583,19 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -586,12 +583,19 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>; LoopSched>;
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype( // desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>; GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype( using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>; GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; // block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -719,10 +723,9 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -719,10 +723,9 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_; ds_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map // block-to-e-tile map
Block2ETileMap block_2_etile_map_; Block2ETileMap block_2_etile_map_;
...@@ -786,10 +789,10 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle ...@@ -786,10 +789,10 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
CDEElementwiseOperation, CDEElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch, ComputePtrOffsetOfStridedBatch,
typename GridwiseGemm::DefaultBlock2ETileMap, DeviceOp::Block2ETileMap,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
......
...@@ -503,13 +503,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout ...@@ -503,13 +503,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!DeviceOp::IsSupportedArgument(arg))
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! unsupported argument");
} }
const index_t grid_size = const index_t grid_size =
......
...@@ -237,10 +237,6 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp ...@@ -237,10 +237,6 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
BElementwiseOperation, BElementwiseOperation,
CDEElementwiseOperation, CDEElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_M_K,
BGridDesc_N_K,
DsGridDesc_M_N,
EGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment