Commit 5d718e6b authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 6d9a07d7 9ce18b04
......@@ -20,14 +20,18 @@ using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto LoopSched = ck::make_default_loop_scheduler();
static constexpr auto PipelineVer = ck::PipelineVersion::v1;
using ComputeTypeA = ck::f8_t;
using ComputeTypeB = ck::f8_t;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| Compute| Compute|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
......
......@@ -27,10 +27,10 @@ using ComputeTypeB = ck::bf8_t;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| Compute| Compute|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
// clang-format on
......
......@@ -109,9 +109,6 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
float max_fp8 = 240.0f;
if(!std::isinf(x))
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
#if defined(__gfx94__)
union
{
......@@ -119,10 +116,15 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
val.fval = x;
uint32_t ival = 0;
const float max_fp8 = 240.0f;
// if x is not +/- infinity or nan
if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
// clip float value
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
......@@ -166,10 +168,15 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
val.fval = x;
uint32_t ival = 0;
const float max_bf8 = 57344.0f;
// if x is not +/- infinity or nan
if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
// clip float value
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
......@@ -208,9 +215,6 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
{
float max_fp8 = 240.0f;
if(!std::isinf(x))
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
#if defined(__gfx94__)
union
{
......@@ -218,8 +222,13 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
val.fval = x;
uint32_t ival = 0;
const float max_fp8 = 240.0f;
// if x is not +/- infinity or nan
if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
// clip float value
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
......@@ -263,8 +272,13 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
val.fval = x;
uint32_t ival = 0;
const float max_bf8 = 57344.0f;
// if x is not +/- infinity or nan
if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
// clip float value
val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment