Commit d6379e50 authored by zhanghj2's avatar zhanghj2
Browse files

实现了scale使用buffer load读取

parent bdf0140b
......@@ -284,6 +284,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
// Raw one-byte storage type used for FP8-encoded values.
typedef unsigned char __hip_fp8_storage_t;
// 8-lane vector of __fp16 (clang ext_vector_type).
typedef __fp16 __fp16x8_t __attribute__((ext_vector_type(8)));
// 4-lane vector of __fp16 (clang ext_vector_type).
typedef __fp16 __fp16x4_t __attribute__((ext_vector_type(4)));
// 2-lane vector of int; receives the result of the two-dword buffer load
// in the buffer-load scale path of this function.
typedef int v2i __attribute__((ext_vector_type(2)));
union Fp8_storage{
__fp16x8_t data_128;
......@@ -390,6 +391,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
} else {
gK_base = k_ptr + offset_k + rel_idx_in_block*(HEAD_DIM_NOPE + HEAD_DIM_ROPE*2);;
static_assert(NUM_SCALES == 8);
#if 1
uint8_t* scale_ptr = k_ptr + offset_k + page_block_size*(HEAD_DIM_NOPE+HEAD_DIM_ROPE*2) + rel_idx_in_block*NUM_SCALES;
if (token_index == -1)
{
......@@ -419,6 +421,37 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
}
}
#else
// Alternative scale-loading path: fetch this token's NUM_SCALES one-byte
// scales with a single two-dword buffer load, then expand them to floats.
//
// Splits a 64-bit address into two 32-bit words so that they can be copied
// into the first two words of the buffer descriptor below. On a
// little-endian target `former` holds the low 32 bits and `latter` the high
// 32 bits.
struct PtrWrapper {
uint32_t former;
uint32_t latter;
};
PtrWrapper glob_ptr;
// Base address of the scale region: it is placed after
// page_block_size * (HEAD_DIM_NOPE + HEAD_DIM_ROPE*2) bytes from k_ptr + offset_k.
*(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(k_ptr + offset_k + page_block_size*(HEAD_DIM_NOPE+HEAD_DIM_ROPE*2));
// Four-word descriptor handed to the buffer-load builtin.
uint32x4_t global_addr = {0};
global_addr[0] = (glob_ptr.former);
global_addr[1] = (glob_ptr.latter);
// NOTE(review): presumably word 2 is the descriptor's record count / size
// limit and word 3 the format/configuration bits; both values are
// hardware-specific — confirm against the target ISA's buffer resource layout.
global_addr[2] = 0x80000000;
global_addr[3] = 0x00020000;
// Offset of this token's scales within the scale region. A token_index of -1
// yields an offset of -1 instead.
// NOTE(review): presumably -1 pushes the access outside the descriptor's
// range so that the load returns zeros for an invalid token — TODO confirm.
int offset_v = token_index == -1 ? -1: rel_idx_in_block*NUM_SCALES;
// Views the two loaded dwords as NUM_SCALES individual bytes (NUM_SCALES is
// statically asserted to be 8 earlier in this branch, i.e. 8 bytes total).
union Scale_e8m0
{
v2i tmp;
__hip_fp8_storage_t fp8_e8m0[NUM_SCALES];
};
Scale_e8m0 scale_e8m0;
// offset_v is passed as the third argument; the remaining arguments are 0.
// NOTE(review): presumably (descriptor, index, byte offset, cache flags) —
// verify against the compiler's builtin definition.
scale_e8m0.tmp = __builtin_amdgcn_buffer_load_dwordx2(global_addr, 0, offset_v, 0, 0);
// Reinterprets a 32-bit pattern as a float.
union Fp32{
uint32_t as_bits;
float as_value;
};
Fp32 fp32;
// Each scale byte is shifted into bits 23..30 (the exponent field of an
// IEEE-754 binary32) with sign and mantissa left at zero, then stored to
// scales[i].
// NOTE(review): the loop stops at NUM_SCALES - 1, so the last loaded byte is
// never converted and scales[NUM_SCALES - 1] is not written here — confirm
// this is intended (e.g. the last byte being padding).
for (int i = 0; i < NUM_SCALES - 1; i++)
{
fp32.as_bits = (scale_e8m0.fp8_e8m0[i] << 23);
scales[i] = fp32.as_value;
}
#endif
// if (block0() && threadIdx.x < 64)
// {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment