Commit 13af8cc4 authored by aska-0096's avatar aska-0096
Browse files

add inline asm for wmmaop test

parent e43df26a
......@@ -97,7 +97,7 @@ builtin_wmma_naive_selector<int4x16_t,
template <typename src_t, typename dst_t, typename acc_t, index_t acc_num>
__global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
{
__shared__ src_t p_shared[16*16*2];
__shared__ src_t p_shared[16 * 16 * 2];
const int lIdx = threadIdx.x;
// a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and
// b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the
......@@ -115,7 +115,7 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
// lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11
// see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482
// TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101
const int lane = lIdx % 16;
const int lane = lIdx % 16;
const int lane_lo = lIdx / 2;
const int lane_hi = lIdx % 2;
for(int ele = 0; ele < 8; ++ele)
......@@ -129,15 +129,15 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
}
__syncthreads();
for(int ele = 0; ele < 8; ++ele)
{
p_shared[8*16*lane_hi + 8 * lane_lo + ele] = a_temp[ele];
p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele] = a_temp[ele];
}
for(int ele = 0; ele < 8; ++ele)
{
p_shared[8*16*lane_hi + 8 * lane_lo + ele + 16*16] = b_temp[ele];
p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele];
}
asm volatile("\
......@@ -147,12 +147,12 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
for(int ele = 0; ele < 16; ++ele)
{
b_frag[ele] = p_shared[(ele/8) * 16*8 + 8 * lane + ele%8 + 16*16];
b_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8 + 16 * 16];
}
// follow origin design
for(int ele = 0; ele < 16; ++ele)
{
a_frag[ele] = p_shared[(ele/8) * 16*8 + 8 * lane + ele%8];
a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8];
}
asm volatile("\
......@@ -163,6 +163,9 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
// sync threads, similar to mma_sync
// __syncthreads();
builtin_wmma_naive_selector<src_vec, acc_vec>(a_frag, b_frag, c_thread_buf_);
// since only fp16_fp32 asm wmma implemented for experiment purpose, restrict test case to fp16
// when enable this ck::amd_assembly_wmma_f32_16x16x16_f16_w32(a_frag, b_frag,
// c_thread_buf_.GetVectorTypeReference(Number<0>{}).template AsType<float8_t>()(Number<0>{}));
__syncthreads();
// wait for results, similar to mma_sync
static_for<0, 8, 1>{}([&](auto ele) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment