fengzch-das / nunchaku · Commits

Commit d21ab0f5, authored Dec 01, 2025 by fengzch
Parent: 181f4e43

    fix: use rocm

Showing 11 changed files with 449 additions and 376 deletions (+449 −376).
README.md                            +1    −0
nunchaku/utils.py                    +1    −1
src/kernels/awq/dequantize.cuh       +18   −7
src/kernels/awq/gemm_awq.cu          +64   −58
src/kernels/utils.cuh                +3    −3
src/kernels/zgemm/attention.cuh      +16   −3
src/kernels/zgemm/gemm_utils.cuh     +105  −74
src/kernels/zgemm/gemm_w4a4.cuh      +6    −3
src/kernels/zgemm/gemm_w8a8.cuh      +38   −35
src/kernels/zgemm/mma.cuh            +59   −59
src/kernels/zgemm/mma_earlycuda.cuh  +138  −133
README.md  (+1 −0)

@@ -4,5 +4,6 @@ source /usr/local/bin/fastpt -T
 export CPLUS_INCLUDE_PATH=/opt/dtk/roctracer/include:$CPLUS_INCLUDE_PATH
 export AMDGPU_TARGETS="gfx906;gfx926;gfx928;gfx936"
+export FASTPT_USE_ASM=1
 CXX=hipcc CC=hipcc python setup.py bdist_wheel
nunchaku/utils.py  (+1 −1)

@@ -308,7 +308,7 @@ def check_hardware_compatibility(quantization_config: dict, device: str | torch.
     if sm == "120":  # you can only use the fp4 models
         if quantization_config["weight"]["dtype"] != "fp4_e2m1_all":
             raise ValueError('Please use "fp4" quantization for Blackwell GPUs. ')
-    elif sm in ["75", "80", "86", "89"]:
+    elif sm in ["75", "80", "86", "89", "92", "93"]:
         if quantization_config["weight"]["dtype"] != "int4":
             raise ValueError('Please use "int4" quantization for Turing, Ampere and Ada GPUs. ')
     else:
src/kernels/awq/dequantize.cuh  (+18 −7)

@@ -12,6 +12,7 @@ https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutl
 #pragma once

 #include <cuda_fp16.h>
+#include <cuda_bf16.h>
 #include <cstdint>

 __forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {

@@ -67,12 +68,17 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uin
     // Finally, we construct the output numbers.
     // Convert elt_01
     // asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
+    h[0] = __hsub(h[0], __float2half(1024.0f));
     // Convert elt_23
     // asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+    h[1] = __hfma(h[1], __float2half(0.0625f), __float2half(-64.0f));
     // Convert elt_45
     // asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
+    h[2] = __hsub(h[2], __float2half(1024.0f));
     // Convert elt_67
     // asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+    h[3] = __hfma(h[3], __float2half(0.0625f), __float2half(-64.0f));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
 }

 __forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {

@@ -121,11 +127,16 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &so
     // Finally, we construct the output numbers.
     // Convert elt_01
-    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(BF16_ONE), "r"(BF16_BIAS));
-    // Convert elt_23
-    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(BF16_ONE), "r"(BF16_BIAS));
-    // Convert elt_45
-    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(BF16_ONE), "r"(BF16_BIAS));
-    // Convert elt_67
-    asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(BF16_ONE), "r"(BF16_BIAS));
+    // asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(BF16_ONE), "r"(BF16_BIAS));
+    // // Convert elt_23
+    // asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(BF16_ONE), "r"(BF16_BIAS));
+    // // Convert elt_45
+    // asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(BF16_ONE), "r"(BF16_BIAS));
+    // // Convert elt_67
+    // asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(BF16_ONE), "r"(BF16_BIAS));
+    h[0] = __hsub(h[0], __float2bfloat16_rn(128.0f));
+    h[1] = __hsub(h[1], __float2bfloat16_rn(128.0f));
+    h[2] = __hsub(h[2], __float2bfloat16_rn(128.0f));
+    h[3] = __hsub(h[3], __float2bfloat16_rn(128.0f));
+    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
 }
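
Note: the fp16 hunk above replaces the PTX sub.f16x2/fma.rn.f16x2 path with __hsub/__hfma intrinsics. A minimal standalone sketch of the magic-number decode those lines perform (illustrative names, not the kernel's actual helper; assumes the usual cuda_fp16.h-style intrinsics are available under hipcc):

    #include <cuda_fp16.h>

    // Decode two 4-bit values that were OR-ed into the mantissa of the fp16
    // "magic" constant 1024.0 (bit pattern 0x6400), where one mantissa ulp == 1.0.
    __device__ __half2 decode_two_nibbles(unsigned int lo_packed, unsigned int hi_packed) {
        // lo_packed = 0x6400 | v_lo         -> reads as 1024 + v_lo
        // hi_packed = 0x6400 | (v_hi << 4)  -> reads as 1024 + 16 * v_hi
        __half lo = __ushort_as_half((unsigned short)lo_packed);
        __half hi = __ushort_as_half((unsigned short)hi_packed);
        lo = __hsub(lo, __float2half(1024.0f));                        // (1024 + v) - 1024 = v
        hi = __hfma(hi, __float2half(0.0625f), __float2half(-64.0f));  // (1024 + 16v)/16 - 64 = v
        return __halves2half2(lo, hi);
    }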
src/kernels/awq/gemm_awq.cu  (+64 −58)

@@ -81,10 +81,10 @@ __device__ void sync_slice(int slice_id) {
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr) {
    uint32_t smem_int_ptr;
    asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
        : "=r"(smem_int_ptr)
        : "l"(ptr));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return smem_int_ptr;
}

@@ -92,38 +92,41 @@ template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                  "ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
                 "{%0, %1, %2, %3}, [%4];"
                 : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
                 : "r"(addr));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}

template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
                  "ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
                 "{%0, %1, %2, %3}, [%4];"
                 : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[1]),
                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[2]),
                   "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[3])
                 : "r"(addr));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}

__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask) {
    const int cp_size = 16;
    asm volatile("{"
                 " .reg .pred p;"
                 " setp.ne.b32 p, %0, 0;"
                 " @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
                 "}" ::"r"((int)mask),
                 "r"(smem_int_ptr),
                 "l"(src),
                 "n"(cp_size));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}

template<typename f16_t>

@@ -131,39 +134,41 @@ __device__ __inline__ void mma_m16n8k16(float *C_warp, f16_t *A_shared_warp, f16
template<>
__device__ __inline__ void mma_m16n8k16<half>(float *C_warp, half *A_shared_warp, half *B_shared_warp) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
        : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
        : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]),
          "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]),
          "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]),
          "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]),
          "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3]));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}

template<>
__device__ __inline__ void mma_m16n8k16<__nv_bfloat16>(float *C_warp, __nv_bfloat16 *A_shared_warp, __nv_bfloat16 *B_shared_warp) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
        : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
        : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]),
          "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]),
          "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]),
          "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]),
          "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3]));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}

template<typename f16_t, int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>

@@ -944,10 +949,11 @@ __global__ void gemm_w4a16_T2(f16_t *__restrict__ A,
                              int M,
                              int N,
                              int K) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
-    trap_unsupported_arch();
-    return;
-#endif
+//#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+//    trap_unsupported_arch();
+//    return;
+//#endif
+    // printf("LOG(INFO) %s: %d %s\n", __FILE__, __LINE__, __func__);
    using f162_t = typename packed_as<f16_t, 2>::type;
    constexpr int NUM_WARPS = CTA_M / WARP_M * CTA_N / WARP_N;
    constexpr int CTA_SIZE  = NUM_WARPS * WARP_SIZE;
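
Note: the helpers above keep their PTX bodies (cvta.to.shared, ldmatrix, cp.async, mma.sync) alongside the HIP-port marker comments. As a mental model for the cp.async helper, a non-asynchronous stand-in could look like the sketch below; this is a hypothetical helper, not part of the commit, and it takes a generic shared-memory pointer rather than the 32-bit shared-window address the asm expects:

    // Predicated 16-byte global -> shared copy without the async pipeline stage.
    __device__ __forceinline__ void cp_16B_pred(uint4 *smem_dst, const uint4 *__restrict__ gmem_src, bool mask) {
        if (mask) {
            *smem_dst = *gmem_src;  // one vectorized 16-byte load/store, completes immediately
        }
    }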
src/kernels/utils.cuh  (+3 −3)

@@ -171,7 +171,7 @@ inline __device__ T ldg(const T *val) {
 #define float22bf162 __float22bfloat162_rn
 #define bf162bf162 __bfloat162bfloat162

 inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 800
     float2 f_val;
     f_val.x = max(min(__low2float(val), 127.f), -128.f);
     f_val.y = max(min(__high2float(val), 127.f), -128.f);

@@ -203,7 +203,7 @@ inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
 #if ENABLE_BF16
 template<>
 inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162 *val) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 800
     return val[0];
 #else
     return __ldg(val);

@@ -212,7 +212,7 @@ inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162 *val) {
 template<>
 inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16 *val) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 800
     return val[0];
 #else
     return __ldg(val);
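
Note: the three hunks above widen the fallback guard from __CUDA_ARCH__ < 800 to <= 800, so the non-__ldg bf16 paths are also taken at arch 800. A rough sketch of what the clamping fallback in bf1622int16 computes (hypothetical helper; the file's own version may round and pack differently):

    #include <cuda_bf16.h>
    #include <cstdint>

    inline __device__ int16_t bf1622int16_fallback(__nv_bfloat162 val) {
        // Clamp each bf16 lane to the int8 range in float, round to nearest,
        // and pack the two bytes into one int16_t (low lane in the low byte).
        float x = fminf(fmaxf(__low2float(val), -128.f), 127.f);
        float y = fminf(fmaxf(__high2float(val), -128.f), 127.f);
        int lo = (int)lrintf(x);
        int hi = (int)lrintf(y);
        return (int16_t)(((hi & 0xFF) << 8) | (lo & 0xFF));
    }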
src/kernels/zgemm/attention.cuh  (+16 −3)

@@ -191,14 +191,26 @@ public:
     // set nan values to -inf
     __device__ __forceinline__ static half2_t fix_nan(half2_t input) {
-        static constexpr float neginf = -std::numeric_limits<float>::infinity();
+        // static constexpr float neginf = -std::numeric_limits<float>::infinity();
         /**
          * In accordance to the IEEE-754R standard,
          * if one of the input parameters to fminf(), fmin(), fmaxf(), or fmax() is NaN,
          * but not the other,
          * the result is the non-NaN parameter.
          */
-        return __hmax2(input, half2_t(neginf, neginf));
+        // return __hmax2(input, half2_t(neginf, neginf));
+        half_t lo = __low2half(input);
+        half_t hi = __high2half(input);
+        // Step 2: Convert to float to use isnan (HIP supports __hisnan)
+        // Option A: Use __hisnan if available (preferred)
+        half_t neg_inf = __float2half(-std::numeric_limits<float>::infinity());
+        half_t out_lo = __hisnan(lo) ? neg_inf : lo;
+        half_t out_hi = __hisnan(hi) ? neg_inf : hi;
+        // Step 3: Pack back into half2_t
+        return __halves2half2(out_lo, out_hi);
     }

     __device__ __forceinline__ static float fix_nan(float input) {

@@ -511,7 +523,8 @@ public:
         if (alwaysfalse) {
             dummy = clock();
         }
-        // asm volatile ("membar.cta;");
+        asm volatile ("membar.cta;");
     }
 }
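
Note: fix_nan is rewritten to test each half lane with __hisnan instead of relying on a max against -inf. Both approaches are sketched below for reference (assumes __half/__half2 and that __hmax2 is available, which on CUDA requires sm_80 or newer; the per-lane form is what the patched HIP path uses):

    #include <cuda_fp16.h>
    #include <cmath>

    // IEEE-754 max(NaN, x) == x, so a max against -inf scrubs NaN lanes in one op.
    __device__ __half2 fix_nan_via_max(__half2 v) {
        const __half ninf = __float2half(-INFINITY);
        return __hmax2(v, __halves2half2(ninf, ninf));
    }

    // Per-lane form: explicit isnan test on each half of the pair.
    __device__ __half2 fix_nan_per_lane(__half2 v) {
        const __half ninf = __float2half(-INFINITY);
        __half lo = __low2half(v), hi = __high2half(v);
        return __halves2half2(__hisnan(lo) ? ninf : lo, __hisnan(hi) ? ninf : hi);
    }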
src/kernels/zgemm/gemm_utils.cuh  (+105 −74)

@@ -50,29 +50,33 @@ template<typename T>
__device__ __forceinline__ static T load_pred(const T *addr, bool pred) {
    if constexpr (sizeof(T) == 4) {
        uint32_t data;
        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
                     "@loadpred ld.global.nc.b32 %0, [%1];"
                     "}"
                     : "=r"(data)
                     : "l"(addr), "r"((int)pred));
        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 8) {
        uint2 data;
        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
                     "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
                     "}"
                     : "=r"(data.x), "=r"(data.y)
                     : "l"(addr), "r"((int)pred));
        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
        return *reinterpret_cast<T *>(&data);
    }
    if constexpr (sizeof(T) == 16) {
        uint4 data;
        asm volatile("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
                     "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
                     "}"
                     : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
                     : "l"(addr), "r"((int)pred));
        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
        return *reinterpret_cast<T *>(&data);
    }

@@ -88,17 +92,21 @@ __device__ __forceinline__ static void store(T *addr, T val) {
    if constexpr (shmem) {
        if constexpr (sizeof(T) == 8) {
            uint2 data = *reinterpret_cast<uint2 *>(&val);
            asm volatile(
                "st.shared.v2.b32 [%0], {%1, %2};" ::"l"((addr)), "r"(data.x), "r"(data.y));
            // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
            return;
        }
        if constexpr (sizeof(T) == 16) {
            uint4 data = *reinterpret_cast<uint4 *>(&val);
            asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};" ::"l"((addr)),
                         "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
            // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
            return;
        }
    }
    *addr = val;

@@ -127,33 +135,39 @@ template<typename T>
__device__ __forceinline__ static void store_pred(T *addr, T val, bool pred) {
    if constexpr (sizeof(T) == 4) {
        uint32_t data = *reinterpret_cast<uint32_t *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.b32 [%1], %2;"
                     "}" ::"r"((int)pred),
                     "l"(addr), "r"(data));
        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
        return;
    }
    if constexpr (sizeof(T) == 8) {
        uint2 data = *reinterpret_cast<uint2 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
                     "}" ::"r"((int)pred),
                     "l"(addr), "r"(data.x), "r"(data.y));
        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
        return;
    }
    if constexpr (sizeof(T) == 16) {
        uint4 data = *reinterpret_cast<uint4 *>(&val);
        asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                     "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
                     "}" ::"r"((int)pred),
                     "l"(addr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
        // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
        return;
    }

@@ -194,14 +208,16 @@ __device__ __forceinline__ static void unused_var(T &val, bool alwaysfalse) {
__device__ __forceinline__ static void ldmatrix(const void *ptr, uint4 &out) {
    asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(out.x), "=r"(out.y), "=r"(out.z), "=r"(out.w)
                 : "l"((ptr)));
    // limengmeng
}

template<typename T>
__device__ __forceinline__ static T movmatrix(T x) {
    asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
                 : "=r"(*reinterpret_cast<uint32_t *>(&x))
                 : "r"(*reinterpret_cast<uint32_t *>(&x)));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return x;
}

@@ -215,7 +231,9 @@ __device__ __forceinline__ uint32_t quantize_float2<4, false>(float2 value) {
    uint32_t result;
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
    asm volatile("cvt.pack.sat.s4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return result;
}

@@ -225,7 +243,9 @@ __device__ __forceinline__ uint32_t quantize_float2<4, true>(float2 value) {
    uint32_t result;
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
    asm volatile("cvt.pack.sat.u4.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return result;
}

@@ -235,22 +255,27 @@ __device__ __forceinline__ uint32_t quantize_float2<8, false>(float2 value) {
    uint32_t result;
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v1) : "f"(value.x));
    asm volatile("cvt.rni.s32.f32 %0, %1;" : "=r"(v2) : "f"(value.y));
    asm volatile("cvt.pack.sat.s8.s32.b32 %0, %1, %2, 0;" : "=r"(result) : "r"(v2), "r"(v1));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return result;
}

__device__ __forceinline__ uint32_t quantize_float2_fp4(float2 value) {
    uint32_t result;
    asm volatile("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
                 : "=r"(result)
                 : "f"(value.y), "f"(value.x));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return result;
}

__device__ __forceinline__ uint32_t quantize_float4_fp8(float4 value) {
    uint16_t lo, hi;
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
    asm volatile("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return uint32_t(lo) | (uint32_t(hi) << 16);
}

@@ -347,15 +372,17 @@ __device__ __forceinline__ static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bf
};

__device__ __forceinline__ static void reduce_add(float *addr, float val) {
    asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;" ::"l"(addr), "f"(val));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}

__device__ __forceinline__ static void reduce_add_pred(float *addr, float val, bool pred) {
    asm volatile("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                 "@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
                 "}" ::"r"((int)pred),
                 "l"(addr), "f"(val));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
}

template<int cnt, typename F>

@@ -369,9 +396,10 @@ __device__ __forceinline__ static void unrolled_loop(F &&lambda) {
__device__ __forceinline__ static float int2float_fast(int val) {
    float fval;
    // fval = (val & 0x7FFFFF) ^ 0x4B400000
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=f"(fval)
                 : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return fval - 12582912.0f;
}

@@ -388,12 +416,13 @@ __device__ __forceinline__ static half2 int2half2_fast_8192(int x, int y) {
    uint32_t ival;
    uint32_t hval;
    // ival.lo = x.lo; ival.hi = y.lo;
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
    ival = ival >> 4;
    // (val & 0x03FF03FF) ^ 0x76007600
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(hval)
                 : "r"(ival), "n"(0x03FF03FF), "n"(0x76007600), "n"((0xF0 & 0xCC) ^ 0xAA));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-24576.0f, -24576.0f));
}

// val in [-4096, 4095], steps of 8, round to nearest

@@ -407,11 +436,12 @@ __device__ __forceinline__ static half2 int2half2_fast_4096_rn(int x, int y) {
    uint32_t hval;
    // ival.lo = x.hi; ival.hi = y.hi;
    // <=> divide x and y by 65536 and pack them
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x7632));
    // (val & 0x03FF03FF) ^ 0x72007200
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(hval)
                 : "r"(ival), "n"(0x03FF03FF), "n"(0x72007200), "n"((0xF0 & 0xCC) ^ 0xAA));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-12288.0f, -12288.0f));
}

// val in [-512, 511]

@@ -420,11 +450,12 @@ __device__ __forceinline__ static half2 int2half2_fast_512(int x, int y) {
    uint32_t hval;
    // ival.lo = x.lo; ival.hi = y.lo;
    // <=> divide x and y by 65536 and pack them
    asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ival) : "r"(x), "r"(y), "n"(0x5410));
    // (val & 0x03FF03FF) ^ 0x66006600
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;"
                 : "=r"(hval)
                 : "r"(ival), "n"(0x03FF03FF), "n"(0x66006600), "n"((0xF0 & 0xCC) ^ 0xAA));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return __hadd2(kernels::bit_cast<half2>(hval), half2(-1536.0f, -1536.0f));
}
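
Note: int2float_fast and the int2half2_fast_* helpers above splice small integers into the mantissa of a floating-point "magic" constant and then subtract that constant. A plain-C++ sketch of the int2float_fast case, with no inline asm (illustrative helper; valid for values in [-2^22, 2^22)):

    #include <cstring>

    __device__ __forceinline__ float int2float_fast_ref(int val) {
        // (val & 0x7FFFFF) ^ 0x4B400000 places val in the mantissa of
        // 0x4B400000 == 12582912.0f == 1.5 * 2^23, where one mantissa ulp == 1.0.
        unsigned int bits = ((unsigned int)val & 0x7FFFFFu) ^ 0x4B400000u;
        float fval;
        memcpy(&fval, &bits, sizeof(fval));  // bit-level reinterpretation
        return fval - 12582912.0f;           // subtract the magic to recover val
    }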
src/kernels/zgemm/gemm_w4a4.cuh  (+6 −3)

@@ -247,7 +247,7 @@ public:
         //              "r"(wmscale),
         //              "n"(0),
         //              "h"((short)(idb * 2 + 1)));
         std::cout << __func__ << "mma_fp4 is not implemented for HIP yet[asm error!!!]" << std::endl;
         // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
         return out;
     }

@@ -334,7 +334,8 @@ public:
             dummy = clock();
         }
-        // asm volatile ("membar.cta;");
+        asm volatile ("membar.cta;");
     }
 }

@@ -916,7 +917,9 @@ public:
     }
     // #endif
-    // asm volatile ("membar.cta;");
+    asm volatile ("membar.cta;");
 }
 }
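
Note: the hunks above toggle the membar.cta asm back on after the timing workaround. Where a toolchain rejects that PTX string, the same CTA-scope ordering can be expressed with a portable intrinsic; this is a possible alternative, not what the commit itself does:

    __device__ __forceinline__ void cta_memory_fence() {
        // Orders this thread's shared/global memory accesses as seen by the rest of
        // the block; maps to membar.cta on NVIDIA targets and to the HIP equivalent.
        __threadfence_block();
    }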
src/kernels/zgemm/gemm_w8a8.cuh  (+38 −35)

@@ -10,40 +10,42 @@ public:
__device__ __forceinline__ static packed_psum_t mma(packed_act_t act, packed_wgt_t wgt, packed_psum_t psum) {
    // packed_psum_t psum;
    asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
                 "{%0, %1, %2, %3},"
                 "{%4, %5, %6, %7},"
                 "{%8, %9},"
                 "{%10, %11, %12, %13};\n"
                 : "=r"(psum.data[0]), "=r"(psum.data[1]), "=r"(psum.data[2]), "=r"(psum.data[3])
                 : "r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
                   "r"(wgt.x), "r"(wgt.y),
                   // "r"(0), "r"(0), "r"(0), "r"(0)
                   "r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3]));
    asm volatile("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
                 "{%0, %1, %2, %3},"
                 "{%4, %5, %6, %7},"
                 "{%8, %9},"
                 "{%10, %11, %12, %13};\n"
                 : "=r"(psum.data[4]), "=r"(psum.data[5]), "=r"(psum.data[6]), "=r"(psum.data[7])
                 : "r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
                   "r"(wgt.z), "r"(wgt.w),
                   // "r"(0), "r"(0), "r"(0), "r"(0)
                   "r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7]));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return psum;
}

@@ -418,7 +420,8 @@ public:
        // dummy = clock();
        // }
-       // asm volatile ("membar.cta;");
+       asm volatile ("membar.cta;");
    }
}
src/kernels/zgemm/mma.cuh  (+59 −59)

@@ -110,65 +110,65 @@ __device__ __forceinline__ static uint4 mma_m16n8kx_s32common(uint4 a, uint2 b,
    uint4 d;
    static constexpr int K = (std::is_same_v<AType, mma_helper::s4> || std::is_same_v<AType, mma_helper::u4>) ? 64 : 32;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("mma.sync.aligned.m16n8k%14.row.col.s32.%15.%16.s32 "
                 "{%0, %1, %2, %3},"
                 "{%4, %5, %6, %7},"
                 "{%8, %9},"
                 "{%10, %11, %12, %13};\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
                   "n"(K), "C"(AType::value), "C"(BType::value));
#else
    asm volatile("{"
                 ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
                 "{tmp0, tmp1},"
                 "{%4},"
                 "{%8},"
                 "{%10, %11};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
                 "{tmp2, tmp3},"
                 "{%5},"
                 "{%8},"
                 "{%12, %13};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
                 "{%0, %1},"
                 "{%6},"
                 "{%9},"
                 "{tmp0, tmp1};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.%15.%16.s32 "
                 "{%2, %3},"
                 "{%7},"
                 "{%9},"
                 "{tmp2, tmp3};\n"
                 "}\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w),
                   "n"(K / 2), "C"(AType::value), "C"(BType::value));
#endif
    return d;
}
src/kernels/zgemm/mma_earlycuda.cuh  (+138 −133)

@@ -36,31 +36,32 @@ using s4u4 = std::conditional_t<is_unsigned, u4, s4>;
__device__ __forceinline__ static uint2 mma_m16n8k16_f16f16f16f16(uint4 a, uint2 b, uint2 c) {
    uint2 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
                 "{%0, %1},"
                 "{%2, %3, %4, %5},"
                 "{%6, %7},"
                 "{%8, %9};\n"
                 : "=r"(d.x), "=r"(d.y)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
#else
    asm volatile("{"
                 ".reg .b32 tmp0, tmp1;"
                 "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
                 "{tmp0, tmp1},"
                 "{%2, %3},"
                 "{%6},"
                 "{%8, %9};\n"
                 "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
                 "{%0, %1},"
                 "{%4, %5},"
                 "{%7},"
                 "{tmp0, tmp1};"
                 "}\n"
                 : "=r"(d.x), "=r"(d.y)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y));
#endif
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return d;
}

@@ -71,13 +72,14 @@ __device__ __forceinline__ static uint4 mma_m16n8k16_f32f16f16f32(uint4 a, uint2
template<>
__device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2 b, uint4 c) {
    uint4 d;
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
                 "{%0, %1, %2, %3},"
                 "{%4, %5, %6, %7},"
                 "{%8, %9},"
                 "{%10, %11, %12, %13};\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return d;
}
#endif

@@ -85,31 +87,32 @@ __device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<true>(uint4 a, uint2
template<>
__device__ __forceinline__ uint4 mma_m16n8k16_f32f16f16f32<false>(uint4 a, uint2 b, uint4 c) {
    uint4 d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                 "{%0, %1, %2, %3},"
                 "{%4, %5, %6, %7},"
                 "{%8, %9},"
                 "{%10, %11, %12, %13};\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
#else
    asm volatile("{"
                 ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
                 "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
                 "{tmp0, tmp1, tmp2, tmp3},"
                 "{%4, %5},"
                 "{%8},"
                 "{%10, %11, %12, %13};\n"
                 "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
                 "{%0, %1, %2, %3},"
                 "{%6, %7},"
                 "{%9},"
                 "{tmp0, tmp1, tmp2, tmp3};"
                 "}\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w));
#endif
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return d;
}

@@ -121,7 +124,7 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helpe
    uint4 d;
    static constexpr int K = 64;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // asm volatile(
    //     "mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
    //     "{%0, %1, %2, %3},"

@@ -130,43 +133,44 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helpe
    //     "{%10, %11, %12, %13};\n"
    //     : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
    //     : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
#else
    asm volatile("{"
                 ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
                 "{tmp0, tmp1},"
                 "{%4},"
                 "{%8},"
                 "{%10, %11};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
                 "{tmp2, tmp3},"
                 "{%5},"
                 "{%8},"
                 "{%12, %13};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
                 "{%0, %1},"
                 "{%6},"
                 "{%9},"
                 "{tmp0, tmp1};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.s4.s4.s32 "
                 "{%2, %3},"
                 "{%7},"
                 "{%9},"
                 "{tmp2, tmp3};\n"
                 "}\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K / 2));
#endif
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return d;
}

@@ -175,7 +179,7 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helpe
    uint4 d;
    static constexpr int K = 64;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // asm volatile(
    //     "mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
    //     "{%0, %1, %2, %3},"

@@ -184,43 +188,44 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helpe
    //     "{%10, %11, %12, %13};\n"
    //     : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
    //     : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
#else
    asm volatile("{"
                 ".reg .b32 tmp0, tmp1, tmp2, tmp3;"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
                 "{tmp0, tmp1},"
                 "{%4},"
                 "{%8},"
                 "{%10, %11};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
                 "{tmp2, tmp3},"
                 "{%5},"
                 "{%8},"
                 "{%12, %13};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
                 "{%0, %1},"
                 "{%6},"
                 "{%9},"
                 "{tmp0, tmp1};\n"
                 "mma.sync.aligned.m8n8k%14.row.col.s32.u4.s4.s32 "
                 "{%2, %3},"
                 "{%7},"
                 "{%9},"
                 "{tmp2, tmp3};\n"
                 "}\n"
                 : "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
                 : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y),
                   "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K / 2));
#endif
    // printf("%s-%s-%d: asm not supportted in Hip yet!\n", __FILE__, __func__, __LINE__);
    return d;
}
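
Note: the pre-sm_80 branches above split one m16n8k16 (or k=64) MMA into two half-K MMAs chained through the accumulator. A scalar reference of that split-K identity, as plain loops rather than the warp-fragment register mapping:

    template<int M, int N, int K>
    __device__ void mma_ref_split_k(const float (&A)[M][K], const float (&B)[K][N], float (&C)[M][N]) {
        constexpr int H = K / 2;
        for (int i = 0; i < M; ++i)
            for (int j = 0; j < N; ++j) {
                float acc = C[i][j];
                for (int k = 0; k < H; ++k) acc += A[i][k] * B[k][j];  // first half-K MMA
                for (int k = H; k < K; ++k) acc += A[i][k] * B[k][j];  // second, fed the partial acc
                C[i][j] = acc;
            }
    }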