Unverified Commit a8629a98 authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into gemm_v2r3_kpad_fix

parents 8dc713ea 94bfa502
......@@ -27,6 +27,12 @@ struct PassThrough
y = x;
}
template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
......@@ -69,18 +75,36 @@ struct PassThrough
y = type_convert<bhalf_t>(x);
}
template <>
__host__ __device__ void operator()<float, half_t>(float& y, const half_t& x) const
{
y = type_convert<float>(x);
}
template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const
{
y = type_convert<half_t>(x);
}
template <>
__host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{
y = type_convert<int8_t>(x);
}
template <>
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
{
y = type_convert<int8_t>(x);
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
__host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
......@@ -89,6 +113,7 @@ struct PassThrough
}
#endif
#if defined CK_ENABLE_FP8
template <>
__host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
{
......@@ -118,6 +143,7 @@ struct PassThrough
{
y = type_convert<f8_t>(x);
}
#endif
};
struct UnaryConvert
......@@ -146,6 +172,7 @@ struct ConvertBF16RTN
}
};
#if defined CK_ENABLE_FP8
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
......@@ -162,6 +189,7 @@ struct ConvertF8SR
y = f8_convert_sr<Y>(x);
}
};
#endif
struct Scale
{
......@@ -412,14 +440,19 @@ struct Swish
{
Swish(float beta = 1.0f) : beta_(beta) {}
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value,
static_assert(is_same<X, float>::value || is_same<X, double>::value ||
is_same<X, ck::half_t>::value,
"Data type is not supported by this operation!");
static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
is_same<Y, ck::half_t>::value,
"Data type is not supported by this operation!");
y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x));
float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
};
float beta_ = 1.0f;
......
This diff is collapsed.
......@@ -268,6 +268,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
"KPerBlock must be divisible by AK1Value and BK1Value!");
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
......
......@@ -4,7 +4,8 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace ck {
......
......@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v;
DstData v;
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert
dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
dst_vector.template AsType<DstData>()(i) = v;
});
const bool is_dst_valid =
......@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
SrcData v;
DstData v;
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
dst_buf(Number<dst_offset>{}) = v;
});
});
}
......
This diff is collapsed.
This diff is collapsed.
......@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
};
#if defined CK_ENABLE_FP8
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8;
......@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif
}
};
#endif
} // namespace ck
#endif
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment