Commit 6252d207 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into ck_tile/fav3_fwd_sept

parents eed60199 e07f1108
......@@ -130,6 +130,8 @@ ENV compiler_commit=$compiler_commit
RUN sh -c "echo compiler version = '$compiler_version'"
RUN sh -c "echo compiler commit = '$compiler_commit'"
ARG DISABLE_CACHE=0
RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
cd llvm-project && mkdir build && cd build && \
......
......@@ -94,7 +94,7 @@ def getDockerImage(Map conf=[:]){
env.DOCKER_BUILDKIT=1
def prefixpath = conf.get("prefixpath", "/opt/rocm")
def no_cache = conf.get("no_cache", false)
def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' --build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}' "
if(no_cache)
{
dockerArgs = dockerArgs + " --no-cache "
......@@ -124,7 +124,7 @@ def buildDocker(install_prefix){
checkout scm
def image_name = getDockerImageName()
echo "Building Docker for ${image_name}"
def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' --build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}' "
echo "Build Args: ${dockerArgs}"
try{
......
......@@ -305,6 +305,14 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
}
#endif
}
else
{
// When the Problem Type and Problem Size does not fit.
std::cerr << gemm.GetTypeString() << ": the instance does not support the problem config."
<< std::endl;
return true;
}
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
......
......@@ -355,12 +355,39 @@ struct UnaryDivide
__host__ __device__ void operator()(T& y, const T& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value,
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = x / type_convert<T>(divider_);
};
template <>
__host__ __device__ void operator()<half_t>(half_t& y, const half_t& x) const
{
float x_ = type_convert<float>(x);
float divider_f_ = type_convert<float>(divider_);
y = type_convert<half_t>(x_ / divider_f_);
};
template <>
__host__ __device__ void operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
float x_ = type_convert<float>(x);
float divider_f_ = type_convert<float>(divider_);
y = type_convert<bhalf_t>(x_ / divider_f_);
};
template <>
__host__ __device__ void operator()<f8_t>(f8_t& y, const f8_t& x) const
{
float x_ = type_convert<float>(x);
float divider_f_ = type_convert<float>(divider_);
y = type_convert<f8_t>(x_ / divider_f_);
};
int32_t divider_ = 1;
};
......
......@@ -35,10 +35,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32f16>
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t idx_part,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
intrin_smfmac_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
a, b, idx, reg_c);
}
};
......@@ -57,10 +63,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16f16>
static constexpr index_t k_per_blk = 16;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t idx_part,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
intrin_smfmac_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
a, b, idx, reg_c);
}
};
......@@ -79,10 +91,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32bf16>
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t idx_part,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
intrin_smfmac_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
a, b, idx, reg_c);
}
};
......@@ -101,10 +119,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16bf16>
static constexpr index_t k_per_blk = 16;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const int32_t& idx, FloatC& reg_c) const
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t idx_part,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const
{
intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run(a, b, idx, reg_c);
intrin_smfmac_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run<FloatC, idx_part>(
a, b, idx, reg_c);
}
};
......@@ -305,8 +329,8 @@ struct SparseXdlopsGemm
"base base_type must be half or bfloat16!");
static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) {
smfmac_instr.template run<MPerXdlops, NPerXdlops>(
p_a_wave[k], p_b_wave[k], idx[k], p_c_thread);
smfmac_instr.template run<MPerXdlops, NPerXdlops, k % 4>(
p_a_wave[k], p_b_wave[k], idx[k / 4], p_c_thread);
});
}
......
......@@ -9,16 +9,18 @@ namespace ck {
template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_16x16x32f16;
// for every smfmac instruction if CBSZ[1:0]=0, ABID[1:0] selects one of four 8-bit sets of sparse
// indices from reg_idx
template <>
struct intrin_smfmac_f32_16x16x32f16<16, 16>
{
template <class FloatC>
template <class FloatC, index_t abid = 0>
__device__ static void
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
#else
ignore = reg_a;
ignore = reg_b;
......@@ -34,13 +36,13 @@ struct intrin_smfmac_f32_16x16x32bf16;
template <>
struct intrin_smfmac_f32_16x16x32bf16<16, 16>
{
template <class FloatC>
template <class FloatC, index_t abid = 0>
__device__ static void
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, abid);
#else
ignore = reg_a;
ignore = reg_b;
......@@ -56,13 +58,13 @@ struct intrin_smfmac_f32_32x32x16f16;
template <>
struct intrin_smfmac_f32_32x32x16f16<32, 32>
{
template <class FloatC>
template <class FloatC, index_t abid = 0>
__device__ static void
Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
#else
ignore = reg_a;
ignore = reg_b;
......@@ -78,13 +80,13 @@ struct intrin_smfmac_f32_32x32x16bf16;
template <>
struct intrin_smfmac_f32_32x32x16bf16<32, 32>
{
template <class FloatC>
template <class FloatC, index_t abid = 0>
__device__ static void
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c)
{
#if defined(__gfx94__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, abid);
#else
ignore = reg_a;
ignore = reg_b;
......
......@@ -52,12 +52,28 @@ struct Add
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value || is_same<T, half_t>::value,
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"The data type is not supported by the Add accumulator!");
a = a + b;
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<f8_t>(a_ + b_);
}
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<half_t>(a_ + b_);
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
......@@ -112,12 +128,28 @@ struct Mul
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, int32_t>::value || is_same<T, half_t>::value,
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"The data type is not supported by the Mul accumulator!");
a = a * b;
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<f8_t>(a_ * b_);
}
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
a = type_convert<half_t>(a_ * b_);
}
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
......@@ -137,6 +169,16 @@ struct Max
float val = NumericLimits<float>::Lowest();
return type_convert<bhalf_t>(val);
}
if constexpr(is_same_v<T, f8_t>)
{
float val = NumericLimits<float>::Lowest();
return type_convert<f8_t>(val);
}
if constexpr(is_same_v<T, half_t>)
{
float val = NumericLimits<float>::Lowest();
return type_convert<half_t>(val);
}
else
{
return NumericLimits<T>::Lowest();
......@@ -154,8 +196,7 @@ struct Max
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
if(a < b)
......@@ -171,12 +212,29 @@ struct Max
a = b;
}
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
a = b;
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"The data type is not supported by the Max accumulator!");
if(a < b)
......@@ -197,6 +255,30 @@ struct Max
changed = true;
}
}
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
{
a = b;
changed = true;
}
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
{
a = b;
changed = true;
}
}
};
struct Min
......@@ -209,6 +291,16 @@ struct Min
float val = NumericLimits<float>::Max();
return type_convert<bhalf_t>(val);
}
else if constexpr(is_same_v<T, half_t>)
{
float val = NumericLimits<float>::Max();
return type_convert<half_t>(val);
}
else if constexpr(is_same_v<T, f8_t>)
{
float val = NumericLimits<float>::Max();
return type_convert<f8_t>(val);
}
else
{
return NumericLimits<T>::Max();
......@@ -227,8 +319,7 @@ struct Min
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
is_same<T, int32_t>::value || is_same<T, int8_t>::value,
"The data type is not supported by the Min accumulator!");
if(a > b)
......@@ -244,6 +335,24 @@ struct Min
a = b;
}
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
a = b;
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
......@@ -270,6 +379,30 @@ struct Min
changed = true;
}
}
__host__ __device__ inline constexpr void operator()(half_t& a, half_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
{
a = b;
changed = true;
}
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
{
a = b;
changed = true;
}
}
};
struct AMax
......@@ -299,6 +432,15 @@ struct AMax
a = b;
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
a = b;
}
template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{
......@@ -313,6 +455,18 @@ struct AMax
changed = true;
}
}
__host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
{
a = b;
changed = true;
}
}
};
template <typename T>
......@@ -352,7 +506,8 @@ struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set,
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value ||
is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value ||
is_same<DataType, f8_t>::value;
};
template <typename DataType>
......@@ -361,7 +516,7 @@ struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Add,
static constexpr bool value =
is_same<DataType, float>::value || is_same<DataType, double>::value ||
is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
is_same<DataType, int32_t>::value;
is_same<DataType, int32_t>::value || is_same<DataType, f8_t>::value;
};
} // namespace reduce
......
......@@ -29,9 +29,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kK0>,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
......@@ -62,9 +62,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kVHeaddim,
Problem::BlockFmhaShape::BlockTile::kK1>,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
......@@ -94,9 +94,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kK2>,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>,
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
......@@ -127,9 +127,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kN0,
Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
Problem::BlockFmhaShape::BlockTile::kK3>,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>,
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
......@@ -159,9 +159,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::BlockTile::kM0,
Problem::BlockFmhaShape::BlockTile::kQKHeaddim,
Problem::BlockFmhaShape::BlockTile::kK4>,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>,
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
......
......@@ -25,7 +25,7 @@ struct GemmKernel
using LayoutA = remove_cvref_t<LayoutA_>;
using LayoutB = remove_cvref_t<LayoutB_>;
using LayoutC = remove_cvref_t<LayoutC_>;
static constexpr index_t KernelBlockSize = GemmPipeline::KernelBlockSize;
static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
......
......@@ -19,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
......
......@@ -195,7 +195,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
......@@ -204,7 +204,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
#if 1 // coalesce reading for each blocks
constexpr index_t M1 = KernelBlockSize / get_warp_size();
constexpr index_t M1 = kBlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = kMPerBlock / (M2 * M1);
......@@ -217,7 +217,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
sequence<1, 2>,
sequence<0, 1>>{});
#else // coalesce reading for each warps
constexpr index_t M0 = KernelBlockSize / get_warp_size();
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M1 = kMPerBlock / (M2 * M0);
return make_static_tile_distribution(
......@@ -235,7 +235,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
......@@ -244,7 +244,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
#if 1 // coalesce reading for each blocks
constexpr index_t N1 = KernelBlockSize / get_warp_size();
constexpr index_t N1 = kBlockSize / get_warp_size();
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = kNPerBlock / (N2 * N1);
......@@ -257,7 +257,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
sequence<1, 2>,
sequence<0, 1>>{});
#else // coalesce reading for each warps
constexpr index_t N0 = KernelBlockSize / get_warp_size();
constexpr index_t N0 = kBlockSize / get_warp_size();
constexpr index_t N1 = kNPerBlock / (N2 * N0);
return make_static_tile_distribution(
......
......@@ -19,7 +19,7 @@ struct BlockGemmPipelineAGmemBGmemCRegV2
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t KernelBlockSize = Problem::KernelBlockSize;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
......
......@@ -23,10 +23,10 @@ struct BlockGemmPipelineProblem
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t KernelBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadA = kPadA_;
static constexpr bool kPadB = kPadB_;
static constexpr bool kPadC = kPadC_;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadA = kPadA_;
static constexpr bool kPadB = kPadB_;
static constexpr bool kPadC = kPadC_;
static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1;
static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_BF16
void add_device_avgpool_2D_bwd_nhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, BF16, BF16, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_FP16
void add_device_avgpool_2D_bwd_nhwc_f16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F16, F16, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_FP8
void add_device_avgpool_2D_bwd_nhwc_f8_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F8, F8, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_FP32
void add_device_avgpool_2D_bwd_nhwc_f32_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, F32, F32, NHWC, NHWC>>>&);
#endif
#ifdef CK_ENABLE_INT8
void add_device_avgpool_2D_bwd_nhwc_int8_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<2, I8, I8, NHWC, NHWC>>>&);
#endif
template <typename DOutDataType, typename DInDataType, typename InLayout, typename OutLayout>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::
DeviceAvgPoolBwd<2, DOutDataType, DInDataType, InLayout, OutLayout>>
{
using DeviceOp = DeviceAvgPoolBwd<2, DOutDataType, DInDataType, InLayout, OutLayout>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InLayout, NHWC> && is_same_v<OutLayout, NHWC>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<DOutDataType, F16> && is_same_v<DInDataType, F16>)
add_device_avgpool_2D_bwd_nhwc_f16_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<DOutDataType, BF16> && is_same_v<DInDataType, BF16>)
add_device_avgpool_2D_bwd_nhwc_bf16_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<DOutDataType, F32> && is_same_v<DInDataType, F32>)
add_device_avgpool_2D_bwd_nhwc_f32_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_FP8
else if constexpr(is_same_v<DOutDataType, F8> && is_same_v<DInDataType, F8>)
add_device_avgpool_2D_bwd_nhwc_f8_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<DOutDataType, I8> && is_same_v<DInDataType, I8>)
add_device_avgpool_2D_bwd_nhwc_int8_instances(op_ptrs);
#endif
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -23,6 +23,11 @@ void add_device_maxpool_bwd_bf16_instances(
void add_device_maxpool_bwd_f32_instances(
std::vector<std::unique_ptr<DeviceMaxPoolBwd<F32, I32, F32>>>&);
#endif
#ifdef CK_ENABLE_INT8
void add_device_maxpool_bwd_int8_instances(
std::vector<std::unique_ptr<DeviceMaxPoolBwd<I8, I32, I8>>>&);
#endif
template <typename DOutDataType, typename IndexDataType, typename DInDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceMaxPoolBwd<DOutDataType, IndexDataType, DInDataType>>
......@@ -32,6 +37,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<DOutDataType, F16> && is_same_v<DInDataType, F16> &&
is_same_v<IndexDataType, I32>)
......@@ -47,6 +53,11 @@ struct DeviceOperationInstanceFactory<
is_same_v<IndexDataType, I32>)
add_device_maxpool_bwd_f32_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<DOutDataType, I8> && is_same_v<DInDataType, I8> &&
is_same_v<IndexDataType, I32>)
add_device_maxpool_bwd_int8_instances(op_ptrs);
#endif
return op_ptrs;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
static constexpr auto InOutRank = 4;
static constexpr auto WindowRank = 2;
static constexpr auto MaxOp = ck::ReduceTensorOp::MAX;
static constexpr auto AvgOp = ck::ReduceTensorOp::AVG;
#ifdef CK_ENABLE_FP16
// FP16
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NHWC, NHWC, AvgOp, false>>>&);
// FP16 - return index
void add_device_pool2d_fwd_nhwc_index_f16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F16, F16, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_BF16
// BF16
void add_device_pool2d_fwd_nhwc_bf16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_bf16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NHWC, NHWC, AvgOp, false>>>&);
// BF16 - return index
void add_device_pool2d_fwd_nhwc_index_bf16_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, BF16, BF16, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
#ifdef CK_ENABLE_FP32
// FP32
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, MaxOp, false>>>&);
void add_device_pool2d_fwd_nhwc_f32_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, AvgOp, false>>>&);
// FP32 - return index
void add_device_pool2d_fwd_nhwc_index_f32_instances(
std::vector<std::unique_ptr<
DevicePoolFwd<InOutRank, WindowRank, F32, F32, I32, NHWC, NHWC, MaxOp, true>>>&);
#endif
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
typename InLayout,
typename OutLayout,
ck::ReduceTensorOp ReduceOpId,
bool OutputIndex>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DevicePoolFwd<InOutRank,
WindowRank,
InDataType,
OutDataType,
IndexDataType,
InLayout,
OutLayout,
ReduceOpId,
OutputIndex>>
{
using DeviceOp = DevicePoolFwd<InOutRank,
WindowRank,
InDataType,
OutDataType,
IndexDataType,
InLayout,
OutLayout,
ReduceOpId,
OutputIndex>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InLayout, NHWC> && is_same_v<OutLayout, NHWC>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<OutDataType, F16> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f16_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_f16_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<OutDataType, BF16> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_bf16_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_bf16_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<OutDataType, F32> &&
is_same_v<IndexDataType, I32>)
{
if constexpr(OutputIndex && ReduceOpId == MaxOp)
{
add_device_pool2d_fwd_nhwc_index_f32_instances(op_ptrs);
}
else
{
add_device_pool2d_fwd_nhwc_f32_instances(op_ptrs);
}
}
#endif
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
set(DEVICE_AVGPOOL_2D_BWD_INSTANCES)
list(APPEND DEVICE_AVGPOOL_2D_BWD_INSTANCES device_avg_pool2d_bwd_nhwc_bf16_instance.cpp
device_avg_pool2d_bwd_nhwc_f16_instance.cpp
device_avg_pool2d_bwd_nhwc_f32_instance.cpp
device_avg_pool2d_bwd_nhwc_f8_instance.cpp
device_avg_pool2d_bwd_nhwc_int8_instance.cpp)
add_instance_library(device_avg_pool2d_bwd_instance ${DEVICE_AVGPOOL_2D_BWD_INSTANCES})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment