Commit bf545630 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed

parent f83a2f38
...@@ -639,6 +639,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -639,6 +639,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
index_t GetKPerBlock() override { return KPerBlock; } index_t GetKPerBlock() override { return KPerBlock; }
bool GetPermuteA() override { return PermuteA; }
bool GetPermuteB() override { return PermuteB; } bool GetPermuteB() override { return PermuteB; }
static auto MakeArgument(const ADataType* p_a, static auto MakeArgument(const ADataType* p_a,
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
namespace ck { namespace ck {
// Fast int4x4 to half8_t data type conversion based on paper // Fast int4x4 to half8_t data type conversion based on paper
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production] // [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
// (https://arxiv.org/abs/2211.10017) and implementation: // (https://arxiv.org/abs/2211.10017) and implementation:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__host__ __device__ inline half4_t pki4_to_half4(int q) __host__ __device__ inline half4_t pki4_to_half4(int q)
{ {
...@@ -22,18 +22,22 @@ __host__ __device__ inline half4_t pki4_to_half4(int q) ...@@ -22,18 +22,22 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
const int HI = 0x00f000f0; const int HI = 0x00f000f0;
const int EX = 0x64006400; const int EX = 0x64006400;
// Extract the two int4 at low bit and create two fp16 number.
int lo = amd_assembly_and_or_b32(q, LO, EX); int lo = amd_assembly_and_or_b32(q, LO, EX);
// Extract the two int4 at hight bit and create two fp16 number.
int hi = amd_assembly_and_or_b32(q, HI, EX); int hi = amd_assembly_and_or_b32(q, HI, EX);
const int SUB = 0xE408E408; //-8 const int SUB = 0xE408E408; // half2 {-1032, -1032}
const int MUL = 0x2c002c00; // 1/16 const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16}
const int ADD = 0xd480d480; //-79 const int ADD = 0xd480d480; // half2 {-72, -72}
vector_type<half_t, 4> res; vector_type<half_t, 4> res;
// for two fp16 from lowbit, subtract 1032 to get correct fp16 value
res.template AsType<half2_t>()(Number<0>{}) = res.template AsType<half2_t>()(Number<0>{}) =
amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB)); amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
// for two fp16 from highbit, divide 16 and subtract 72 to get correct fp16 value
res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16( res.template AsType<half2_t>()(Number<1>{}) = amd_assembly_pk_fma_f16(
bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD)); bit_cast<half2_t>(hi), bit_cast<half2_t>(MUL), bit_cast<half2_t>(ADD));
......
...@@ -407,7 +407,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -407,7 +407,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
else else
{ {
// Pre-shuffled Weight // Pre-shuffled Weight
// BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
constexpr index_t BK01 = KPerBlock / BK1Value; constexpr index_t BK01 = KPerBlock / BK1Value;
// const index_t BK00 = BK0 / BK01; // const index_t BK00 = BK0 / BK01;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment