Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
181f4e43
Commit
181f4e43
authored
Nov 24, 2025
by
fengzch
Browse files
fix compile error
parent
4cdcd76f
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
133 additions
and
99 deletions
+133
-99
src/kernels/awq/gemv_awq.cu
src/kernels/awq/gemv_awq.cu
+6
-6
src/kernels/zgemm/epilogues.cuh
src/kernels/zgemm/epilogues.cuh
+21
-3
src/kernels/zgemm/gemm_base.cuh
src/kernels/zgemm/gemm_base.cuh
+28
-13
src/kernels/zgemm/gemm_w4a4.cuh
src/kernels/zgemm/gemm_w4a4.cuh
+60
-59
src/kernels/zgemm/mma.cuh
src/kernels/zgemm/mma.cuh
+1
-1
src/kernels/zgemm/mma_earlycuda.cuh
src/kernels/zgemm/mma_earlycuda.cuh
+17
-17
No files found.
src/kernels/awq/gemv_awq.cu
View file @
181f4e43
...
@@ -106,12 +106,12 @@ __global__ void gemv_kernel(const half_t *inputs,
...
@@ -106,12 +106,12 @@ __global__ void gemv_kernel(const half_t *inputs,
const
int
IC
,
const
int
IC
,
const
int
OC
)
{
const
int
OC
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
//
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
if
constexpr
(
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
)
{
//
if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
trap_unsupported_arch
();
//
trap_unsupported_arch();
return
;
//
return;
}
//
}
#endif
//
#endif
using
half2_t
=
typename
packed_as
<
half_t
,
2
>::
type
;
using
half2_t
=
typename
packed_as
<
half_t
,
2
>::
type
;
using
accum_t
=
float
;
using
accum_t
=
float
;
using
accum2_t
=
typename
packed_as
<
accum_t
,
2
>::
type
;
using
accum2_t
=
typename
packed_as
<
accum_t
,
2
>::
type
;
...
...
src/kernels/zgemm/epilogues.cuh
View file @
181f4e43
...
@@ -608,7 +608,16 @@ public:
...
@@ -608,7 +608,16 @@ public:
packed_fpsum_t
v
=
fpsum
[
i
*
WARP_N_TILES
+
head_id
*
(
LITELA_HEAD_DIM
*
2
)
/
16
+
packed_fpsum_t
v
=
fpsum
[
i
*
WARP_N_TILES
+
head_id
*
(
LITELA_HEAD_DIM
*
2
)
/
16
+
LITELA_HEAD_DIM
/
16
+
tile_v
];
LITELA_HEAD_DIM
/
16
+
tile_v
];
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
k
.
data
[
j
]
=
__hmax2
(
k
.
data
[
j
],
half2_t
(
0
,
0
));
// relu
__hip_bfloat162
first
;
first
.
x
=
float
(
k
.
data
[
j
].
x
);
first
.
y
=
float
(
k
.
data
[
j
].
y
);
auto
temp
=
half2_t
(
0
,
0
);
__hip_bfloat162
sec
;
sec
.
x
=
float
(
temp
.
x
);
sec
.
y
=
float
(
temp
.
y
);
auto
relu_result
=
__hmax2
(
first
,
sec
);
// relu
k
.
data
[
j
].
x
=
float
(
relu_result
.
x
);
k
.
data
[
j
].
y
=
float
(
relu_result
.
y
);
}
}
attn_sum
=
mma_litela
(
k
,
v
,
attn_sum
);
attn_sum
=
mma_litela
(
k
,
v
,
attn_sum
);
}
}
...
@@ -632,7 +641,16 @@ public:
...
@@ -632,7 +641,16 @@ public:
packed_fpsum_t
k
=
fpsum
[
i
*
WARP_N_TILES
+
head_id
*
(
LITELA_HEAD_DIM
*
2
)
/
16
+
tile_k
];
packed_fpsum_t
k
=
fpsum
[
i
*
WARP_N_TILES
+
head_id
*
(
LITELA_HEAD_DIM
*
2
)
/
16
+
tile_k
];
packed_fpsum_t
v
=
{};
packed_fpsum_t
v
=
{};
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
k
.
data
[
j
]
=
__hmax2
(
k
.
data
[
j
],
half2_t
(
0
,
0
));
// relu
__hip_bfloat162
first
;
first
.
x
=
float
(
k
.
data
[
j
].
x
);
first
.
y
=
float
(
k
.
data
[
j
].
y
);
auto
temp
=
half2_t
(
0
,
0
);
__hip_bfloat162
sec
;
sec
.
x
=
float
(
temp
.
x
);
sec
.
y
=
float
(
temp
.
y
);
auto
relu_result
=
__hmax2
(
first
,
sec
);
// relu
k
.
data
[
j
].
x
=
float
(
relu_result
.
x
);
k
.
data
[
j
].
y
=
float
(
relu_result
.
y
);
}
}
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
...
@@ -801,7 +819,7 @@ public:
...
@@ -801,7 +819,7 @@ public:
fpsum_warp
fpsum
;
fpsum_warp
fpsum
;
Base
::
template
load_act_to_fpsum
<
false
>()(
args
.
input
+
m_offset
*
args
.
actualN
+
n_offset
,
typename
Base
::
template
load_act_to_fpsum
<
false
>()(
args
.
input
+
m_offset
*
args
.
actualN
+
n_offset
,
args
.
actualN
,
args
.
actualN
,
args
.
actualM
-
m_offset
,
args
.
actualM
-
m_offset
,
args
.
actualN
-
n_offset
,
args
.
actualN
-
n_offset
,
...
...
src/kernels/zgemm/gemm_base.cuh
View file @
181f4e43
#pragma once
#pragma once
#include <hip/amd_detail/amd_hip_bf16.h>
#include "common.h"
#include "common.h"
...
@@ -8,7 +9,7 @@
...
@@ -8,7 +9,7 @@
#include "mma_earlycuda.cuh"
#include "mma_earlycuda.cuh"
#pragma nv_diag_suppress 177
#pragma nv_diag_suppress 177
#define __DTK_ARCH__ 1200
#ifdef _MSC_VER
#ifdef _MSC_VER
#define ALWAYSINLINE [[msvc::forceinline]]
#define ALWAYSINLINE [[msvc::forceinline]]
#else
#else
...
@@ -208,11 +209,11 @@ public:
...
@@ -208,11 +209,11 @@ public:
uint4
out1
=
mma_m16n8k16_f32f16f16f32
<
is_bf16
>
(
uint4
out1
=
mma_m16n8k16_f32f16f16f32
<
is_bf16
>
(
kernels
::
bit_cast
<
uint4
>
(
a
),
kernels
::
bit_cast
<
uint4
>
(
a
),
kernels
::
bit_cast
<
uint2
>
(
std
::
array
<
half2_t
,
2
>
(
b
.
data
[
0
],
b
.
data
[
1
]
)
),
kernels
::
bit_cast
<
uint2
>
(
std
::
array
<
half2_t
,
2
>
{
b
.
data
[
0
],
b
.
data
[
1
]
}
),
kernels
::
bit_cast
<
uint4
>
(
float4
(
psum
.
data
[
0
],
psum
.
data
[
1
],
psum
.
data
[
2
],
psum
.
data
[
3
])));
kernels
::
bit_cast
<
uint4
>
(
float4
(
psum
.
data
[
0
],
psum
.
data
[
1
],
psum
.
data
[
2
],
psum
.
data
[
3
])));
uint4
out2
=
mma_m16n8k16_f32f16f16f32
<
is_bf16
>
(
uint4
out2
=
mma_m16n8k16_f32f16f16f32
<
is_bf16
>
(
kernels
::
bit_cast
<
uint4
>
(
a
),
kernels
::
bit_cast
<
uint4
>
(
a
),
kernels
::
bit_cast
<
uint2
>
(
std
::
array
<
half2_t
,
2
>
(
b
.
data
[
2
],
b
.
data
[
3
]
)
),
kernels
::
bit_cast
<
uint2
>
(
std
::
array
<
half2_t
,
2
>
{
b
.
data
[
2
],
b
.
data
[
3
]
}
),
kernels
::
bit_cast
<
uint4
>
(
float4
(
psum
.
data
[
4
],
psum
.
data
[
5
],
psum
.
data
[
6
],
psum
.
data
[
7
])));
kernels
::
bit_cast
<
uint4
>
(
float4
(
psum
.
data
[
4
],
psum
.
data
[
5
],
psum
.
data
[
6
],
psum
.
data
[
7
])));
psum
.
data
[
0
]
=
kernels
::
bit_cast
<
float
>
(
out1
.
x
);
psum
.
data
[
0
]
=
kernels
::
bit_cast
<
float
>
(
out1
.
x
);
psum
.
data
[
1
]
=
kernels
::
bit_cast
<
float
>
(
out1
.
y
);
psum
.
data
[
1
]
=
kernels
::
bit_cast
<
float
>
(
out1
.
y
);
...
@@ -344,14 +345,28 @@ public:
...
@@ -344,14 +345,28 @@ public:
const
int
packIdx
=
k
/
(
WSCALES_PACK_SIZE
*
WARP_SIZE
);
const
int
packIdx
=
k
/
(
WSCALES_PACK_SIZE
*
WARP_SIZE
);
const
int
srcLane
=
4
*
(
k
/
WSCALES_PACK_SIZE
)
+
laneId
%
4
;
const
int
srcLane
=
4
*
(
k
/
WSCALES_PACK_SIZE
)
+
laneId
%
4
;
const
int
elementIdx
=
k
%
WSCALES_PACK_SIZE
/
2
;
const
int
elementIdx
=
k
%
WSCALES_PACK_SIZE
/
2
;
return
__shfl_sync
(
~
0
,
block
[
packIdx
].
data
[
elementIdx
],
srcLane
);
half2
temp
;
temp
.
x
=
float
(
block
[
packIdx
].
data
[
elementIdx
].
x
);
temp
.
y
=
float
(
block
[
packIdx
].
data
[
elementIdx
].
y
);
half2
res
=
__shfl
(
temp
,
srcLane
);
half2_t
result
;
result
.
x
=
(
float
)
res
.
x
;
result
.
y
=
(
float
)
res
.
y
;
return
result
;
}
}
// get {k}-th and {k+1}-th ascale from the block, k must be multiples of 2, k must be uniform across all lanes
// get {k}-th and {k+1}-th ascale from the block, k must be multiples of 2, k must be uniform across all lanes
__device__
__forceinline__
static
half2_t
broadcast_ascale
(
ascale_warp
block
,
int
k
,
int
laneId
)
{
__device__
__forceinline__
static
half2_t
broadcast_ascale
(
ascale_warp
block
,
int
k
,
int
laneId
)
{
const
int
packIdx
=
k
/
(
ASCALES_PACK_SIZE
*
WARP_SIZE
);
const
int
packIdx
=
k
/
(
ASCALES_PACK_SIZE
*
WARP_SIZE
);
const
int
srcLane
=
8
*
(
k
/
ASCALES_PACK_SIZE
)
+
laneId
/
4
;
const
int
srcLane
=
8
*
(
k
/
ASCALES_PACK_SIZE
)
+
laneId
/
4
;
const
int
elementIdx
=
k
%
ASCALES_PACK_SIZE
/
2
;
const
int
elementIdx
=
k
%
ASCALES_PACK_SIZE
/
2
;
return
__shfl_sync
(
~
0
,
block
[
packIdx
].
data
[
elementIdx
],
srcLane
);
half2
temp
;
temp
.
x
=
float
(
block
[
packIdx
].
data
[
elementIdx
].
x
);
temp
.
y
=
float
(
block
[
packIdx
].
data
[
elementIdx
].
y
);
half2
res
=
__shfl
(
temp
,
srcLane
);
half2_t
result
;
result
.
x
=
(
float
)
res
.
x
;
result
.
y
=
(
float
)
res
.
y
;
return
result
;
}
}
struct
i2f_normal
{
struct
i2f_normal
{
...
@@ -897,16 +912,16 @@ constexpr int max_arch() {
...
@@ -897,16 +912,16 @@ constexpr int max_arch() {
template
<
typename
kernel
,
typename
...
T
>
template
<
typename
kernel
,
typename
...
T
>
__global__
static
void
invoke_kernel
(
T
...
args
)
{
__global__
static
void
invoke_kernel
(
T
...
args
)
{
#ifdef __CUDA_ARCH__
//
#ifdef __CUDA_ARCH__
if
constexpr
(
__CUDA_ARCH__
>=
min_arch
<
kernel
>
()
&&
__CUDA_ARCH__
<=
max_arch
<
kernel
>
())
{
//
if constexpr (__CUDA_ARCH__ >= min_arch<kernel>() && __CUDA_ARCH__ <= max_arch<kernel>()) {
kernel
()(
args
...);
//
kernel()(args...);
}
else
{
//
} else {
trap_unsupported_arch
();
//
trap_unsupported_arch();
}
//
}
#else
//
#else
// ???
// ???
kernel
()(
args
...);
kernel
()(
args
...);
#endif
//
#endif
}
}
template
<
typename
T
>
template
<
typename
T
>
...
...
src/kernels/zgemm/gemm_w4a4.cuh
View file @
181f4e43
...
@@ -122,8 +122,8 @@ public:
...
@@ -122,8 +122,8 @@ public:
for
(
int
mask
=
2
;
mask
>
0
;
mask
/=
2
)
{
for
(
int
mask
=
2
;
mask
>
0
;
mask
/=
2
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_GROUPS
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_GROUPS
;
i
++
)
{
maxvalue
[
0
][
i
]
=
__hmax
(
maxvalue
[
0
][
i
],
__shfl_xor
(
maxvalue
[
0
][
i
],
mask
));
maxvalue
[
0
][
i
]
=
__hmax
(
maxvalue
[
0
][
i
],
__shfl_xor
(
float
(
maxvalue
[
0
][
i
]
)
,
mask
));
maxvalue
[
1
][
i
]
=
__hmax
(
maxvalue
[
1
][
i
],
__shfl_xor
(
maxvalue
[
1
][
i
],
mask
));
maxvalue
[
1
][
i
]
=
__hmax
(
maxvalue
[
1
][
i
],
__shfl_xor
(
float
(
maxvalue
[
1
][
i
]
)
,
mask
));
}
}
}
}
// lane 0,1,2,3 / 4,5,6,7 / ... should have identical maxvalue now
// lane 0,1,2,3 / 4,5,6,7 / ... should have identical maxvalue now
...
@@ -197,56 +197,57 @@ public:
...
@@ -197,56 +197,57 @@ public:
int
ida
,
int
ida
,
int
idb
)
{
int
idb
)
{
packed_f32psum_t
out
;
packed_f32psum_t
out
;
asm
volatile
(
// asm volatile(
"mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
// "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
"{%0, %1, %2, %3}, "
// "{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
// "{%4, %5, %6, %7}, "
"{%8, %9}, "
// "{%8, %9}, "
"{%10, %11, %12, %13}, "
// "{%10, %11, %12, %13}, "
"{%14}, {%15, %16}, "
// "{%14}, {%15, %16}, "
"{%17}, {%18, %19};"
// "{%17}, {%18, %19};"
:
"=f"
(
out
.
data
[
0
]),
"=f"
(
out
.
data
[
1
]),
"=f"
(
out
.
data
[
2
]),
"=f"
(
out
.
data
[
3
])
// : "=f"(out.data[0]), "=f"(out.data[1]), "=f"(out.data[2]), "=f"(out.data[3])
:
"r"
(
act
.
x
),
// : "r"(act.x),
"r"
(
act
.
y
),
// "r"(act.y),
"r"
(
act
.
z
),
// "r"(act.z),
"r"
(
act
.
w
),
// "r"(act.w),
"r"
(
wgt
.
x
),
// "r"(wgt.x),
"r"
(
wgt
.
y
),
// "r"(wgt.y),
"f"
(
psum
.
data
[
0
]),
// "f"(psum.data[0]),
"f"
(
psum
.
data
[
1
]),
// "f"(psum.data[1]),
"f"
(
psum
.
data
[
2
]),
// "f"(psum.data[2]),
"f"
(
psum
.
data
[
3
]),
// "f"(psum.data[3]),
"r"
(
amscale
),
// "r"(amscale),
"n"
(
0
),
// "n"(0),
"h"
((
short
)
ida
),
// "h"((short)ida),
"r"
(
wmscale
),
// "r"(wmscale),
"n"
(
0
),
// "n"(0),
"h"
((
short
)(
idb
*
2
)));
// "h"((short)(idb * 2)));
asm
volatile
(
// asm volatile(
"mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
// "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
"{%0, %1, %2, %3}, "
// "{%0, %1, %2, %3}, "
"{%4, %5, %6, %7}, "
// "{%4, %5, %6, %7}, "
"{%8, %9}, "
// "{%8, %9}, "
"{%10, %11, %12, %13}, "
// "{%10, %11, %12, %13}, "
"{%14}, {%15, %16}, "
// "{%14}, {%15, %16}, "
"{%17}, {%18, %19};"
// "{%17}, {%18, %19};"
:
"=f"
(
out
.
data
[
4
]),
"=f"
(
out
.
data
[
5
]),
"=f"
(
out
.
data
[
6
]),
"=f"
(
out
.
data
[
7
])
// : "=f"(out.data[4]), "=f"(out.data[5]), "=f"(out.data[6]), "=f"(out.data[7])
:
"r"
(
act
.
x
),
// : "r"(act.x),
"r"
(
act
.
y
),
// "r"(act.y),
"r"
(
act
.
z
),
// "r"(act.z),
"r"
(
act
.
w
),
// "r"(act.w),
"r"
(
wgt
.
z
),
// "r"(wgt.z),
"r"
(
wgt
.
w
),
// "r"(wgt.w),
"f"
(
psum
.
data
[
4
]),
// "f"(psum.data[4]),
"f"
(
psum
.
data
[
5
]),
// "f"(psum.data[5]),
"f"
(
psum
.
data
[
6
]),
// "f"(psum.data[6]),
"f"
(
psum
.
data
[
7
]),
// "f"(psum.data[7]),
"r"
(
amscale
),
// "r"(amscale),
"n"
(
0
),
// "n"(0),
"h"
((
short
)
ida
),
// "h"((short)ida),
"r"
(
wmscale
),
// "r"(wmscale),
"n"
(
0
),
// "n"(0),
"h"
((
short
)(
idb
*
2
+
1
)));
// "h"((short)(idb * 2 + 1)));
std
::
cout
<<
__func__
<<
"mma_fp4 is not implemented for HIP yet[asm error!!!]"
<<
std
::
endl
;
return
out
;
return
out
;
}
}
...
@@ -465,11 +466,11 @@ public:
...
@@ -465,11 +466,11 @@ public:
}
}
#pragma unroll
#pragma unroll
for
(
int
mask
=
2
;
mask
>
0
;
mask
/=
2
)
{
for
(
int
mask
=
2
;
mask
>
0
;
mask
/=
2
)
{
maxvalue
[
0
]
=
__hmax
(
maxvalue
[
0
],
__shfl_xor
(
maxvalue
[
0
],
mask
));
maxvalue
[
0
]
=
__hmax
(
maxvalue
[
0
],
__shfl_xor
(
float
(
maxvalue
[
0
]
)
,
mask
));
maxvalue
[
1
]
=
__hmax
(
maxvalue
[
1
],
__shfl_xor
(
maxvalue
[
1
],
mask
));
maxvalue
[
1
]
=
__hmax
(
maxvalue
[
1
],
__shfl_xor
(
float
(
maxvalue
[
1
]
)
,
mask
));
}
}
maxvalue
[
0
]
=
__shfl
_sync
(
~
0
,
maxvalue
[
0
],
laneId
/
4
*
4
);
maxvalue
[
0
]
=
__shfl
(
float
(
maxvalue
[
0
]
)
,
laneId
/
4
*
4
);
maxvalue
[
1
]
=
__shfl
_sync
(
~
0
,
maxvalue
[
1
],
laneId
/
4
*
4
);
maxvalue
[
1
]
=
__shfl
(
float
(
maxvalue
[
1
]
)
,
laneId
/
4
*
4
);
float
scale
[
2
];
float
scale
[
2
];
// scale[0] = float(maxvalue[0]) / QVALUE_MAX;
// scale[0] = float(maxvalue[0]) / QVALUE_MAX;
...
@@ -577,14 +578,14 @@ public:
...
@@ -577,14 +578,14 @@ public:
for
(
int
mask
=
NUM_PACKS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
for
(
int
mask
=
NUM_PACKS_PER_ROW
/
2
;
mask
>
0
;
mask
/=
2
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_PACKWARPS
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_PACKWARPS
;
i
++
)
{
maxvalue
[
i
]
=
__hmax
(
maxvalue
[
i
],
__shfl_xor
(
maxvalue
[
i
],
mask
));
maxvalue
[
i
]
=
__hmax
(
maxvalue
[
i
],
__shfl_xor
(
float
(
maxvalue
[
i
]
)
,
mask
));
}
}
}
}
// broadcast (max)
// broadcast (max)
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_PACKWARPS
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_PACKWARPS
;
i
++
)
{
maxvalue
[
i
]
=
__shfl
_sync
(
~
0
,
maxvalue
[
i
],
laneId
/
NUM_PACKS_PER_ROW
*
NUM_PACKS_PER_ROW
);
maxvalue
[
i
]
=
__shfl
(
float
(
maxvalue
[
i
]
)
,
laneId
/
NUM_PACKS_PER_ROW
*
NUM_PACKS_PER_ROW
);
}
}
// quantize
// quantize
...
@@ -1150,7 +1151,7 @@ public:
...
@@ -1150,7 +1151,7 @@ public:
fpsum_warp
fpsum
;
fpsum_warp
fpsum
;
Base
::
template
load_act_to_fpsum
<
fuse_glu
>()(
args
.
input
+
m_offset
*
args
.
actualN
+
n_offset
,
typename
Base
::
template
load_act_to_fpsum
<
fuse_glu
>()(
args
.
input
+
m_offset
*
args
.
actualN
+
n_offset
,
args
.
actualN
,
args
.
actualN
,
args
.
actualM
-
m_offset
,
args
.
actualM
-
m_offset
,
args
.
actualN
-
n_offset
,
args
.
actualN
-
n_offset
,
...
...
src/kernels/zgemm/mma.cuh
View file @
181f4e43
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#include <cstdint>
#include <cstdint>
#include "common.h"
#include "common.h"
#define __DTK_ARCH__ 1200
// only supports cuda 12.5+
// only supports cuda 12.5+
namespace
nunchaku
::
kernels
{
namespace
nunchaku
::
kernels
{
...
...
src/kernels/zgemm/mma_earlycuda.cuh
View file @
181f4e43
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
#include <cstdint>
#include <cstdint>
#include "common.h"
#include "common.h"
#define __DTK_ARCH__ 1200
// cuda 12.4- does not support "C" constraint in inline assembly :(
// cuda 12.4- does not support "C" constraint in inline assembly :(
// use explicit specialization for now
// use explicit specialization for now
...
@@ -122,14 +122,14 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helpe
...
@@ -122,14 +122,14 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::s4, mma_helpe
static
constexpr
int
K
=
64
;
static
constexpr
int
K
=
64
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
//
asm volatile(
"mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
//
"mma.sync.aligned.m16n8k%14.row.col.s32.s4.s4.s32 "
"{%0, %1, %2, %3},"
//
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
//
"{%4, %5, %6, %7},"
"{%8, %9},"
//
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
//
"{%10, %11, %12, %13};\n"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
//
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
));
//
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
#else
#else
asm
volatile
(
"{"
asm
volatile
(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
...
@@ -176,14 +176,14 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helpe
...
@@ -176,14 +176,14 @@ __device__ __forceinline__ uint4 mma_m16n8kx_s32common<mma_helper::u4, mma_helpe
static
constexpr
int
K
=
64
;
static
constexpr
int
K
=
64
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
//
asm volatile(
"mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
//
"mma.sync.aligned.m16n8k%14.row.col.s32.u4.s4.s32 "
"{%0, %1, %2, %3},"
//
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
//
"{%4, %5, %6, %7},"
"{%8, %9},"
//
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
//
"{%10, %11, %12, %13};\n"
:
"=r"
(
d
.
x
),
"=r"
(
d
.
y
),
"=r"
(
d
.
z
),
"=r"
(
d
.
w
)
//
: "=r"(d.x), "=r"(d.y), "=r"(d.z), "=r"(d.w)
:
"r"
(
a
.
x
),
"r"
(
a
.
y
),
"r"
(
a
.
z
),
"r"
(
a
.
w
),
"r"
(
b
.
x
),
"r"
(
b
.
y
),
"r"
(
c
.
x
),
"r"
(
c
.
y
),
"r"
(
c
.
z
),
"r"
(
c
.
w
),
"n"
(
K
));
//
: "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), "r"(b.x), "r"(b.y), "r"(c.x), "r"(c.y), "r"(c.z), "r"(c.w), "n"(K));
#else
#else
asm
volatile
(
"{"
asm
volatile
(
"{"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
".reg .b32 tmp0, tmp1, tmp2, tmp3;"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment