Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
fengzch-das
nunchaku
Commits
bf4adfeb
Commit
bf4adfeb
authored
Mar 26, 2025
by
sxtyzhangzk
Committed by
Zhekai Zhang
Apr 01, 2025
Browse files
Support Turing (sm_75) architecture
parent
3ef186fd
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
115 additions
and
106 deletions
+115
-106
src/kernels/zgemm/gemm_w4a4.cuh
src/kernels/zgemm/gemm_w4a4.cuh
+111
-105
src/kernels/zgemm/gemm_w8a8.cuh
src/kernels/zgemm/gemm_w8a8.cuh
+1
-0
src/kernels/zgemm/zgemm.h
src/kernels/zgemm/zgemm.h
+2
-0
third_party/Block-Sparse-Attention
third_party/Block-Sparse-Attention
+1
-1
No files found.
src/kernels/zgemm/gemm_w4a4.cuh
View file @
bf4adfeb
...
@@ -335,6 +335,7 @@ public:
...
@@ -335,6 +335,7 @@ public:
template
<
typename
Epilogue
,
bool
USE_ALPHA
>
template
<
typename
Epilogue
,
bool
USE_ALPHA
>
struct
gemm_w4a4_fp4_kernel
{
struct
gemm_w4a4_fp4_kernel
{
static
constexpr
int
MIN_ARCH
=
1200
;
__device__
__device__
void
operator
()(
void
operator
()(
const
packed_act_t
*
act
,
const
packed_act_t
*
act
,
...
@@ -389,67 +390,16 @@ public:
...
@@ -389,67 +390,16 @@ public:
static
packed_psum_t
mma
(
packed_act_t
act
,
packed_wgt_t
wgt
)
{
static
packed_psum_t
mma
(
packed_act_t
act
,
packed_wgt_t
wgt
)
{
packed_psum_t
psum
;
packed_psum_t
psum
;
if
constexpr
(
!
ACT_UNSIGNED
)
{
uint4
out1
=
mma_m16n8kx_s32common
<
mma_helper
::
s4u4
<
ACT_UNSIGNED
>
,
mma_helper
::
s4
>
(
act
,
uint2
(
wgt
.
x
,
wgt
.
y
),
uint4
(
0
,
0
,
0
,
0
));
asm
volatile
(
uint4
out2
=
mma_m16n8kx_s32common
<
mma_helper
::
s4u4
<
ACT_UNSIGNED
>
,
mma_helper
::
s4
>
(
act
,
uint2
(
wgt
.
z
,
wgt
.
w
),
uint4
(
0
,
0
,
0
,
0
));
"mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 "
psum
.
data
[
0
]
=
out1
.
x
;
"{%0, %1, %2, %3},"
psum
.
data
[
1
]
=
out1
.
y
;
"{%4, %5, %6, %7},"
psum
.
data
[
2
]
=
out1
.
z
;
"{%8, %9},"
psum
.
data
[
3
]
=
out1
.
w
;
"{%10, %11, %12, %13};
\n
"
psum
.
data
[
4
]
=
out2
.
x
;
:
psum
.
data
[
5
]
=
out2
.
y
;
"=r"
(
psum
.
data
[
0
]),
"=r"
(
psum
.
data
[
1
]),
"=r"
(
psum
.
data
[
2
]),
"=r"
(
psum
.
data
[
3
])
psum
.
data
[
6
]
=
out2
.
z
;
:
psum
.
data
[
7
]
=
out2
.
w
;
"r"
(
act
.
x
),
"r"
(
act
.
y
),
"r"
(
act
.
z
),
"r"
(
act
.
w
),
"r"
(
wgt
.
x
),
"r"
(
wgt
.
y
),
"r"
(
0
),
"r"
(
0
),
"r"
(
0
),
"r"
(
0
)
// "r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3])
);
asm
volatile
(
"mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
psum
.
data
[
4
]),
"=r"
(
psum
.
data
[
5
]),
"=r"
(
psum
.
data
[
6
]),
"=r"
(
psum
.
data
[
7
])
:
"r"
(
act
.
x
),
"r"
(
act
.
y
),
"r"
(
act
.
z
),
"r"
(
act
.
w
),
"r"
(
wgt
.
z
),
"r"
(
wgt
.
w
),
"r"
(
0
),
"r"
(
0
),
"r"
(
0
),
"r"
(
0
)
// "r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7])
);
}
if
constexpr
(
ACT_UNSIGNED
)
{
asm
volatile
(
"mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
psum
.
data
[
0
]),
"=r"
(
psum
.
data
[
1
]),
"=r"
(
psum
.
data
[
2
]),
"=r"
(
psum
.
data
[
3
])
:
"r"
(
act
.
x
),
"r"
(
act
.
y
),
"r"
(
act
.
z
),
"r"
(
act
.
w
),
"r"
(
wgt
.
x
),
"r"
(
wgt
.
y
),
"r"
(
0
),
"r"
(
0
),
"r"
(
0
),
"r"
(
0
)
// "r"(psum.data[0]), "r"(psum.data[1]), "r"(psum.data[2]), "r"(psum.data[3])
);
asm
volatile
(
"mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};
\n
"
:
"=r"
(
psum
.
data
[
4
]),
"=r"
(
psum
.
data
[
5
]),
"=r"
(
psum
.
data
[
6
]),
"=r"
(
psum
.
data
[
7
])
:
"r"
(
act
.
x
),
"r"
(
act
.
y
),
"r"
(
act
.
z
),
"r"
(
act
.
w
),
"r"
(
wgt
.
z
),
"r"
(
wgt
.
w
),
"r"
(
0
),
"r"
(
0
),
"r"
(
0
),
"r"
(
0
)
// "r"(psum.data[4]), "r"(psum.data[5]), "r"(psum.data[6]), "r"(psum.data[7])
);
}
return
psum
;
return
psum
;
}
}
...
@@ -554,7 +504,7 @@ public:
...
@@ -554,7 +504,7 @@ public:
// [WARP_M, WARP_N * 2] when fuse_glu
// [WARP_M, WARP_N * 2] when fuse_glu
template
<
bool
fuse_glu
>
template
<
bool
fuse_glu
>
struct
load_act_to_fpsum
{
struct
load_act_to_fpsum
{
using
matrix_t
=
half_t
[
WARP
_M
][
WARP_N
+
8
];
using
matrix_t
=
half_t
[
INSN
_M
][
WARP_N
+
8
];
static
constexpr
size_t
SHMEM_SIZE
=
sizeof
(
matrix_t
);
static
constexpr
size_t
SHMEM_SIZE
=
sizeof
(
matrix_t
);
__device__
__forceinline__
__device__
__forceinline__
...
@@ -568,41 +518,42 @@ public:
...
@@ -568,41 +518,42 @@ public:
using
packed_raw_input
=
std
::
array
<
half2_t
,
PACK_SIZE
>
;
using
packed_raw_input
=
std
::
array
<
half2_t
,
PACK_SIZE
>
;
#pragma unroll
#pragma unroll
for
(
int
row
=
0
;
row
<
WARP_M
;
row
++
)
{
for
(
int
m
=
0
;
m
<
WARP_M_TILES
;
m
++
)
{
packed_input
pack
;
#pragma unroll
// TODO: numCols not multiples of PACK_SIZE
for
(
int
row
=
0
;
row
<
INSN_M
;
row
++
)
{
if
constexpr
(
fuse_glu
)
{
packed_input
pack
;
packed_raw_input
raw
;
// TODO: numCols not multiples of PACK_SIZE
raw
.
fill
(
half2_t
(
0
,
0
));
if
constexpr
(
fuse_glu
)
{
bool
pred
=
row
<
maxRows
&&
laneId
*
PACK_SIZE
*
2
<
maxCols
;
packed_raw_input
raw
;
if
(
pred
)
{
raw
.
fill
(
half2_t
(
0
,
0
));
raw
=
load
(
reinterpret_cast
<
const
packed_raw_input
*>
(
input
+
row
*
stride
+
laneId
*
PACK_SIZE
*
2
));
bool
pred
=
(
m
*
INSN_M
+
row
)
<
maxRows
&&
laneId
*
PACK_SIZE
*
2
<
maxCols
;
}
if
(
pred
)
{
#pragma unroll
raw
=
load
(
reinterpret_cast
<
const
packed_raw_input
*>
(
input
+
(
m
*
INSN_M
+
row
)
*
stride
+
laneId
*
PACK_SIZE
*
2
));
for
(
int
j
=
0
;
j
<
PACK_SIZE
;
j
++
)
{
}
pack
[
j
]
=
raw
[
j
].
x
*
silu
(
raw
[
j
].
y
);
#pragma unroll
}
for
(
int
j
=
0
;
j
<
PACK_SIZE
;
j
++
)
{
}
else
{
pack
[
j
]
=
raw
[
j
].
x
*
silu
(
raw
[
j
].
y
);
pack
.
fill
(
half_t
(
0
));
}
bool
pred
=
row
<
maxRows
&&
laneId
*
PACK_SIZE
<
maxCols
;
}
else
{
if
(
pred
)
{
pack
.
fill
(
half_t
(
0
));
pack
=
load
(
reinterpret_cast
<
const
packed_input
*>
(
input
+
row
*
stride
+
laneId
*
PACK_SIZE
));
bool
pred
=
(
m
*
INSN_M
+
row
)
<
maxRows
&&
laneId
*
PACK_SIZE
<
maxCols
;
if
(
pred
)
{
pack
=
load
(
reinterpret_cast
<
const
packed_input
*>
(
input
+
(
m
*
INSN_M
+
row
)
*
stride
+
laneId
*
PACK_SIZE
));
}
}
}
store
<
true
>
(
reinterpret_cast
<
packed_input
*>
(
&
mat
[
row
][
laneId
*
PACK_SIZE
]),
pack
);
}
}
store
<
true
>
(
reinterpret_cast
<
packed_input
*>
(
&
mat
[
row
][
laneId
*
PACK_SIZE
]),
pack
);
__syncwarp
();
}
__syncwarp
();
for
(
int
m
=
0
;
m
<
WARP_M_TILES
;
m
++
)
{
for
(
int
n
=
0
;
n
<
WARP_N_TILES
;
n
++
)
{
for
(
int
n
=
0
;
n
<
WARP_N_TILES
;
n
++
)
{
const
int
row
=
m
*
INSN_M
+
laneId
%
16
;
const
int
row
=
laneId
%
16
;
const
int
col
=
n
*
INSN_N
+
laneId
/
16
*
8
;
const
int
col
=
n
*
INSN_N
+
laneId
/
16
*
8
;
uint4
tmp
;
uint4
tmp
;
ldmatrix
(
&
mat
[
row
][
col
],
tmp
);
ldmatrix
(
&
mat
[
row
][
col
],
tmp
);
*
reinterpret_cast
<
uint4
*>
(
&
out
[
m
*
WARP_N_TILES
+
n
])
=
tmp
;
*
reinterpret_cast
<
uint4
*>
(
&
out
[
m
*
WARP_N_TILES
+
n
])
=
tmp
;
}
}
__syncwarp
();
}
}
__syncwarp
();
}
}
};
};
...
@@ -707,6 +658,7 @@ public:
...
@@ -707,6 +658,7 @@ public:
// each thread block (1 warp) quantize WARP_M * WARP_K tile (32 * 64)
// each thread block (1 warp) quantize WARP_M * WARP_K tile (32 * 64)
struct
quantize_w4a4_act_kernel
{
struct
quantize_w4a4_act_kernel
{
static
constexpr
int
MIN_ARCH
=
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
?
800
:
750
;
__device__
__device__
void
operator
()(
const
half_t
*
input
,
packed_act_t
*
output
,
packed_ascale_t
*
oscales
,
int
K
)
{
void
operator
()(
const
half_t
*
input
,
packed_act_t
*
output
,
packed_ascale_t
*
oscales
,
int
K
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
...
@@ -744,6 +696,7 @@ public:
...
@@ -744,6 +696,7 @@ public:
// each thread block (1 warp) quantize WARP_N * WARP_K tile (128 * 64)
// each thread block (1 warp) quantize WARP_N * WARP_K tile (128 * 64)
struct
quantize_w4a4_wgt_kernel
{
struct
quantize_w4a4_wgt_kernel
{
static
constexpr
int
MIN_ARCH
=
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
?
800
:
750
;
__device__
__device__
void
operator
()(
const
half_t
*
input
,
packed_wgt_t
*
output
,
packed_wscale_t
*
oscales
,
int
K
)
{
void
operator
()(
const
half_t
*
input
,
packed_wgt_t
*
output
,
packed_wscale_t
*
oscales
,
int
K
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
...
@@ -777,10 +730,54 @@ public:
...
@@ -777,10 +730,54 @@ public:
}
}
};
};
struct
i2f_sm80
{
__device__
__forceinline__
static
float2
int2float2
(
int
x
,
int
y
)
{
return
make_float2
(
int2float_fast
(
x
),
int2float_fast
(
y
));
}
__device__
__forceinline__
static
half2_t
int2half2
(
int
x
,
int
y
)
{
return
float22half2
<
half2_t
>
(
int2float2
(
x
,
y
));
}
};
struct
i2f_sm75
{
__device__
__forceinline__
static
float2
int2float2
(
int
x
,
int
y
)
{
return
make_float2
(
int2float_fast
(
x
),
int2float_fast
(
y
));
}
__device__
__forceinline__
static
half2_t
int2half2
(
int
x
,
int
y
)
{
return
half2
(
__int2half_rn
(
x
),
__int2half_rn
(
y
));
}
};
struct
i2f_sm75_fast
{
__device__
__forceinline__
static
float2
int2float2
(
int
x
,
int
y
)
{
return
make_float2
(
int2float_fast
(
x
),
int2float_fast
(
y
));
}
__device__
__forceinline__
static
half2_t
int2half2
(
int
x
,
int
y
)
{
return
int2half2_fast_512
(
x
,
y
);
}
};
template
<
bool
ACT_UNSIGNED
,
typename
T
>
template
<
bool
ACT_UNSIGNED
,
typename
T
>
__device__
__forceinline__
__device__
__forceinline__
static
void
compute
(
act_warp
A
,
wgt_warp
W
,
ascale_warp
ascale
,
wscale_warp
wscale
,
T
&
fpsum
)
{
static
void
compute
(
act_warp
A
,
wgt_warp
W
,
ascale_warp
ascale
,
wscale_warp
wscale
,
T
&
fpsum
)
{
apply_scales
<
true
>
([
&
](
int
i
,
int
j
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
using
int2half2
=
i2f_sm80
;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
using
int2half2
=
std
::
conditional_t
<
Config
::
FASTER_I2F
,
i2f_sm75_fast
,
i2f_sm75
>
;;
#else
using
int2half2
=
Base
::
i2f_normal
;
#endif
Base
::
template
apply_scales
<
int2half2
>([
&
](
int
i
,
int
j
)
{
return
mma
<
ACT_UNSIGNED
>
(
A
[
i
],
W
[
j
]);
return
mma
<
ACT_UNSIGNED
>
(
A
[
i
],
W
[
j
]);
},
ascale
,
wscale
,
fpsum
);
},
ascale
,
wscale
,
fpsum
);
}
}
...
@@ -875,7 +872,7 @@ public:
...
@@ -875,7 +872,7 @@ public:
}
}
// out: [M / BLOCK_M, N / BLOCK_N, NUM_WARPS, 1, NUM_M_TILES, NUM_N_TILES, WARP_SIZE] of fpsum_warp
// out: [M / BLOCK_M, N / BLOCK_N, NUM_WARPS, 1, NUM_M_TILES, NUM_N_TILES, WARP_SIZE] of fpsum_warp
template
<
typename
Epilogue
,
bool
ACT_UNSIGNED
>
template
<
typename
Epilogue
,
bool
ACT_UNSIGNED
,
bool
USE_FP32_ACCUM
>
__device__
__forceinline__
__device__
__forceinline__
static
void
gemm_w4a4_block
(
static
void
gemm_w4a4_block
(
const
BlockInfo
binfo
,
const
BlockInfo
binfo
,
...
@@ -902,7 +899,7 @@ public:
...
@@ -902,7 +899,7 @@ public:
wgt_warp
W
[
NUM_STAGES
];
// 32
wgt_warp
W
[
NUM_STAGES
];
// 32
ascale_warp
ascale
[
NUM_STAGES
];
// 1
ascale_warp
ascale
[
NUM_STAGES
];
// 1
wscale_warp
wscale
[
NUM_STAGES
];
// 2
wscale_warp
wscale
[
NUM_STAGES
];
// 2
fpsum_warp
fpsum
;
// 64
std
::
conditional_t
<
USE_FP32_ACCUM
,
f32psum_warp
,
fpsum_warp
>
fpsum
;
// 64
// load_wscale<true>(wscales, wscale[0], true);
// load_wscale<true>(wscales, wscale[0], true);
// load_wscale<false>(wscales, wscale[1], true);
// load_wscale<false>(wscales, wscale[1], true);
...
@@ -916,16 +913,16 @@ public:
...
@@ -916,16 +913,16 @@ public:
}
}
for
(
auto
&
pack
:
fpsum
)
{
for
(
auto
&
pack
:
fpsum
)
{
#if 1
if
constexpr
(
USE_FP32_ACCUM
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
pack
.
data
[
i
].
x
=
0
;
pack
.
data
[
i
]
=
0
;
pack
.
data
[
i
].
y
=
0
;
}
}
}
else
{
#else
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
pack
.
data
[
i
].
x
=
0
;
pack
.
data
[
i
]
=
0
;
pack
.
data
[
i
].
y
=
0
;
}
}
}
#endif
}
}
int
dummy
=
0
;
int
dummy
=
0
;
...
@@ -949,9 +946,11 @@ public:
...
@@ -949,9 +946,11 @@ public:
compute
<
ACT_UNSIGNED
>
(
A
[
k2
],
W
[
k2
],
ascale
[
k2
],
wscale
[
k2
],
fpsum
);
compute
<
ACT_UNSIGNED
>
(
A
[
k2
],
W
[
k2
],
ascale
[
k2
],
wscale
[
k2
],
fpsum
);
//#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
if
(
alwaysfalse
)
{
if
(
alwaysfalse
)
{
dummy
=
clock
();
dummy
=
clock
();
}
}
//#endif
// asm volatile ("membar.cta;");
// asm volatile ("membar.cta;");
}
}
...
@@ -961,11 +960,12 @@ public:
...
@@ -961,11 +960,12 @@ public:
#endif
#endif
#if 0
fpsum_warp
f16psum
;
auto f16psum = packed_fp32_to_fp16(fpsum);
if
constexpr
(
USE_FP32_ACCUM
)
{
#else
f16psum
=
packed_fp32_to_fp16
(
fpsum
);
auto
f16psum
=
fpsum
;
}
else
{
#endif
f16psum
=
fpsum
;
}
CHECK_NAN
(
f16psum
,
"f16psum"
);
CHECK_NAN
(
f16psum
,
"f16psum"
);
...
@@ -1324,6 +1324,7 @@ public:
...
@@ -1324,6 +1324,7 @@ public:
struct
quantize_w4a4_fuse_lora_kernel
{
struct
quantize_w4a4_fuse_lora_kernel
{
using
oscales_t
=
typename
std
::
conditional_t
<
use_fp4
,
packed_amscale_t
,
packed_ascale_t
>
;
using
oscales_t
=
typename
std
::
conditional_t
<
use_fp4
,
packed_amscale_t
,
packed_ascale_t
>
;
static
constexpr
int
MIN_ARCH
=
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
?
800
:
750
;
static
constexpr
size_t
SHMEM_PER_WARP
=
ceilDiv
<
size_t
>
(
load_act_to_fpsum
<
fuse_glu
>::
SHMEM_SIZE
,
128
)
*
128
;
static
constexpr
size_t
SHMEM_PER_WARP
=
ceilDiv
<
size_t
>
(
load_act_to_fpsum
<
fuse_glu
>::
SHMEM_SIZE
,
128
)
*
128
;
static
constexpr
size_t
SHMEM_SIZE
=
SHMEM_PER_WARP
*
NUM_WARPS
;
static
constexpr
size_t
SHMEM_SIZE
=
SHMEM_PER_WARP
*
NUM_WARPS
;
...
@@ -2059,6 +2060,7 @@ public:
...
@@ -2059,6 +2060,7 @@ public:
// q: [batch_size, #blocks, block_size, #heads, HEAD_DIM]
// q: [batch_size, #blocks, block_size, #heads, HEAD_DIM]
// vk: [batch_size, #heads, HEAD_DIM+1, HEAD_DIM]
// vk: [batch_size, #heads, HEAD_DIM+1, HEAD_DIM]
struct
vk_mul_q_kernel
{
struct
vk_mul_q_kernel
{
static
constexpr
int
MIN_ARCH
=
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
?
800
:
750
;
// FIXME FIXME FIXME
// FIXME FIXME FIXME
__device__
__device__
void
operator
()(
half_t
*
q
,
const
float
*
vk
,
float
eps
,
int
num_tokens
)
{
void
operator
()(
half_t
*
q
,
const
float
*
vk
,
float
eps
,
int
num_tokens
)
{
...
@@ -2116,6 +2118,9 @@ public:
...
@@ -2116,6 +2118,9 @@ public:
template
<
typename
Epilogue
,
bool
ACT_UNSIGNED
>
template
<
typename
Epilogue
,
bool
ACT_UNSIGNED
>
struct
gemm_w4a4_kernel
{
struct
gemm_w4a4_kernel
{
static
constexpr
int
MIN_ARCH
=
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
?
800
:
750
;
static
constexpr
int
MAX_ARCH
=
Config
::
FASTER_I2F
?
750
:
INT_MAX
;
// FASTER_I2F is only needed on sm_75
__device__
__device__
void
operator
()(
void
operator
()(
const
packed_act_t
*
act
,
const
packed_act_t
*
act
,
...
@@ -2146,7 +2151,7 @@ public:
...
@@ -2146,7 +2151,7 @@ public:
// bool fusequant = !out;
// bool fusequant = !out;
gemm_w4a4_block
<
Epilogue
,
ACT_UNSIGNED
>
(
gemm_w4a4_block
<
Epilogue
,
ACT_UNSIGNED
,
false
>
(
binfo
,
binfo
,
act
+
bm
*
(
K
/
WARP_K
)
*
NUM_WARPS
*
WARP_M_TILES
*
WARP_SIZE
,
act
+
bm
*
(
K
/
WARP_K
)
*
NUM_WARPS
*
WARP_M_TILES
*
WARP_SIZE
,
wgt
+
bn
*
(
K
/
WARP_K
)
*
WARP_N_TILES
*
WARP_SIZE
,
wgt
+
bn
*
(
K
/
WARP_K
)
*
WARP_N_TILES
*
WARP_SIZE
,
...
@@ -2164,6 +2169,7 @@ public:
...
@@ -2164,6 +2169,7 @@ public:
template
<
typename
Epilogue
>
template
<
typename
Epilogue
>
struct
test_epilogue_kernel
{
struct
test_epilogue_kernel
{
static
constexpr
int
MIN_ARCH
=
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
?
800
:
750
;
static
constexpr
size_t
SHMEM_PER_WARP
=
ceilDiv
<
size_t
>
(
load_act_to_fpsum
<
false
>::
SHMEM_SIZE
,
128
)
*
128
;
static
constexpr
size_t
SHMEM_PER_WARP
=
ceilDiv
<
size_t
>
(
load_act_to_fpsum
<
false
>::
SHMEM_SIZE
,
128
)
*
128
;
static
constexpr
size_t
SHMEM_SIZE
=
SHMEM_PER_WARP
*
NUM_WARPS
;
static
constexpr
size_t
SHMEM_SIZE
=
SHMEM_PER_WARP
*
NUM_WARPS
;
...
...
src/kernels/zgemm/gemm_w8a8.cuh
View file @
bf4adfeb
...
@@ -448,6 +448,7 @@ public:
...
@@ -448,6 +448,7 @@ public:
// out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
// out : [M / BLOCK_M, BLOCK_M, N / BLOCK_N, BLOCK_N]
template
<
typename
Epilogue
>
template
<
typename
Epilogue
>
struct
gemm_w8a8_kernel
{
struct
gemm_w8a8_kernel
{
static
constexpr
int
MIN_ARCH
=
std
::
is_same_v
<
half_t
,
__nv_bfloat16
>
?
800
:
750
;
__device__
__device__
void
operator
()(
void
operator
()(
const
packed_act_t
*
act
,
const
packed_act_t
*
act
,
...
...
src/kernels/zgemm/zgemm.h
View file @
bf4adfeb
...
@@ -69,6 +69,8 @@ void attention_fp16(
...
@@ -69,6 +69,8 @@ void attention_fp16(
float
scale
float
scale
);
);
// EXPERIMENTAL, for sm_75
void
set_faster_i2f_mode
(
std
::
string
mode
);
// FOR TEST ONLY
// FOR TEST ONLY
void
test_rmsnorm_rope
(
Tensor
input
,
Tensor
output
,
Tensor
norm_q
,
Tensor
norm_k
,
Tensor
rotary_emb
);
void
test_rmsnorm_rope
(
Tensor
input
,
Tensor
output
,
Tensor
norm_q
,
Tensor
norm_k
,
Tensor
rotary_emb
);
...
...
Block-Sparse-Attention
@
99511c34
Compare
0d23f715
...
99511c34
Subproject commit
0d23f715690c5171fd93679de8afd149376db167
Subproject commit
99511c34554a13ffaa81321834faf66389ffcb30
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment