fengzch-das / nunchaku · Commits

Commit e05256c8 — authored Apr 01, 2025 by sxtyzhangzk, committed by Zhekai Zhang on Apr 01, 2025
Flexible lora ranks
Parent: 618f3078

Showing 8 changed files with 1393 additions and 1216 deletions (+1393 −1216)
  src/kernels/zgemm/epilogues.cuh              +803     −0
  src/kernels/zgemm/gemm_base.cuh               +72     −5
  src/kernels/zgemm/gemm_utils.cuh               +8     −0
  src/kernels/zgemm/gemm_w4a4.cuh               +45  −1113
  src/kernels/zgemm/gemm_w4a4_launch.cuh         +3     −8
  src/kernels/zgemm/gemm_w4a4_launch_impl.cuh   +98    −88
  src/kernels/zgemm/gemm_w4a4_test.cu            +3     −2
  src/kernels/zgemm/lora.cuh                   +361     −0
src/kernels/zgemm/epilogues.cuh  (new file, 0 → 100644)
#pragma once

#include "gemm_base.cuh"

namespace nunchaku::kernels {

template<typename Config>
class Epilogues;

#ifndef __INTELLISENSE__
template<typename Config>
class Epilogues : public GEMMBase<Config> {
#else
template<>
class Epilogues<GEMMConfig_W4A4_FP16> : public GEMMBase<GEMMConfig_W4A4_FP16> {
    using Config = GEMMConfig_W4A4_FP16;
#endif

public:
    IMPORT_GEMM_BASE(Config);

public:
    struct EpilogueGelu {
        struct Arguments { size_t unused; };

        // static constexpr float SHIFT_VALUE = 0.171875f;

        __device__ __forceinline__
        void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
#pragma unroll
            for (int i = 0; i < WARP_M_TILES; i++) {
#pragma unroll
                for (int j = 0; j < WARP_N_TILES; j++) {
#pragma unroll
                    for (int k = 0; k < 4; k++) {
                        half2_t &data = fpsum[i * WARP_N_TILES + j].data[k];
                        data = gelu_half2(data);
                        // data = __hadd2(data, half2_t(SHIFT_VALUE, SHIFT_VALUE));
                    }
                }
            }
        }
    };
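    // Sketch, not part of this commit: gelu_half2 is assumed to be provided by the shared kernel
    // headers. A tanh-approximation GELU over one half2 pair, written with the helpers already
    // used in this file, could look like the following.
    __device__ __forceinline__ static half2_t gelu_half2_sketch(half2_t h) {
        float2 v = half22float2(h);
        auto g = [](float x) {
            return 0.5f * x * (1.0f + tanhf(0.7978845608f * (x + 0.044715f * x * x * x)));
        };
        return float22half2<half2_t>(make_float2(g(v.x), g(v.y)));
    }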
    // template<int PoolSize = 128>
    struct EpilogueQKVProj {
        struct Arguments {
            half_t *out;
            int actualM, actualN;

            half_t *pool_out;                // [M / PoolSize, N]
            const float *rotary_emb;         // [M, HEAD_DIM / 2, ROTARY_EMB_NUM_ELEMENTS]
            const half_t *rmsnorm_weight_q;  // [HEAD_DIM]
            const half_t *rmsnorm_weight_k;  // [HEAD_DIM]
            float epsilon;
        };

        static constexpr int HEAD_DIM = 128;
        static constexpr int NUM_HEADS_PER_WARP = WARP_N / HEAD_DIM;

        static constexpr int PoolSize = 128;
        static constexpr int NUM_WARPS_PER_POOL = PoolSize / WARP_M;
        static constexpr int NUM_POOLS_PER_BLOCK = BLOCK_M / PoolSize;

        static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2;  // 1 for theta, 2 for {sin, cos} pair

        __device__ __forceinline__
        static void apply(fpsum_warp fpsum, half_t *out, int M, int N, int K,
                          half_t *pool_out, const float *rotary_emb, const half_t *rmsnorm_weight,
                          float epsilon, int maxRows)
        {
            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;

            __shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];

            constexpr int PACK_SIZE = unpack_fpsum::PACK_SIZE;
            using pack_t = unpack_fpsum::pack_t;
            using pack_rope_t = std::array<float, PACK_SIZE / 2 * ROTARY_EMB_NUM_ELEMENTS>;

            constexpr int LANES_PER_HEAD = HEAD_DIM / PACK_SIZE;

            pack_t reduce_tmp;
            __shared__ alignas(128) pack_t pool[NUM_WARPS];

            // load rmsnorm scales
            pack_t rms;
            if (laneId < LANES_PER_HEAD) {
                rms = load(reinterpret_cast<const pack_t *>(&rmsnorm_weight[laneId * PACK_SIZE]));
            }
            if constexpr (LANES_PER_HEAD < WARP_SIZE) {
                for (int i = 0; i < PACK_SIZE; i++) {
                    rms[i] = __shfl_sync(~0, rms[i], laneId % LANES_PER_HEAD);
                }
            }

            const float *rotary_emb_base_addr = &rotary_emb[
                (warpId * WARP_M) * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS +
                laneId * PACK_SIZE / 2 * ROTARY_EMB_NUM_ELEMENTS];

            CHECK_NAN(fpsum, "fpsum");

            unpack_fpsum()(fpsum, out + warpId * WARP_M * N, N, maxRows - warpId * WARP_M, INT_MAX, shmem[warpId],
                           [&](int rowId, pack_t &pack) ALWAYSINLINE {
                // load rope
                pack_rope_t rope;
                if (laneId < LANES_PER_HEAD) {
                    // freq = load(reinterpret_cast<pack_freq_t *>(&freqs_cis[(warpId * WARP_M + rowId) * HEAD_DIM * 2 + laneId * PACK_SIZE * 2]));
                    rope = load(reinterpret_cast<const pack_rope_t *>(
                        &rotary_emb_base_addr[rowId * HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS]));
                }
                if constexpr (LANES_PER_HEAD < WARP_SIZE) {
                    for (int i = 0; i < rope.size(); i++) {
                        rope[i] = __shfl_sync(~0, rope[i], laneId % LANES_PER_HEAD);
                    }
                }

                // rmsnorm
                float sqrsum = 0.0f;
                for (int i = 0; i < PACK_SIZE; i++) {
                    sqrsum += float(pack[i]) * float(pack[i]);
                    CHECK_NAN(sqrsum, "sqrsum");
                }
#pragma unroll
                for (int mask = LANES_PER_HEAD / 2; mask > 0; mask /= 2) {
                    sqrsum += __shfl_xor_sync(~0, sqrsum, mask);
                }
                sqrsum /= HEAD_DIM;
                float coef = cuda_frsqrt(sqrsum + epsilon);
                CHECK_NAN(coef, "coef");

                for (int i = 0; i < PACK_SIZE; i++) {
                    pack[i] *= coef * float(rms[i]);
                    CHECK_NAN(rms[i], "rms.wgt");
                    CHECK_NAN(pack[i], "rms.out");
                }

#if 1
                // rope
                for (int i = 0; i < PACK_SIZE; i += 2) {
                    float2 pack2 = half22float2(half2_t(pack[i], pack[i + 1]));

                    CHECK_NAN(freq[i].x, "rope.freq");
                    CHECK_NAN(freq[i].y, "rope.freq");
                    CHECK_NAN(freq[i + 1].x, "rope.freq");
                    CHECK_NAN(freq[i + 1].y, "rope.freq");

                    // half2_t tmp = __hmul2(freq[i], pack2);
                    // tmp = __hfma2(freq[i+1], pack2, tmp);
                    // pack[i] = tmp.x;
                    // pack[i+1] = tmp.y;

                    // printf("block.x=%d block.y=%d warpId=%d rowId=%d (%d) freqs = %f %f %f %f\n",
                    //     blockIdx.x, blockIdx.y, warpId, rowId,
                    //     blockIdx.x * BLOCK_M + warpId * WARP_M + rowId,
                    //     (float)freq[i].x, (float)freq[i].y, (float)freq[i+1].x, (float)freq[i+1].y
                    // );
                    // __trap();

                    // half2_t tmp = __hmul2(half2_t(pack2.x, pack2.x), freq[i]);
                    // tmp = __hfma2(half2_t(pack2.y, pack2.y), freq[i+1], tmp);
                    // pack[i] = tmp.x;
                    // pack[i+1] = tmp.y;

                    float sin, cos;
                    if constexpr (ROTARY_EMB_NUM_ELEMENTS == 1) {
                        sin = cuda_sin(rope[i / 2]);
                        cos = cuda_cos(rope[i / 2]);
                    }
                    if constexpr (ROTARY_EMB_NUM_ELEMENTS == 2) {
                        sin = rope[i];
                        cos = rope[i + 1];
                    }

                    // pack[i] = pack2.x * freq[i].x + pack2.y * freq[i].y;
                    // pack[i+1] = pack2.x * freq[i+1].x + pack2.y * freq[i+1].y;
                    pack[i]     = half_t(pack2.x * cos - pack2.y * sin);
                    pack[i + 1] = half_t(pack2.x * sin + pack2.y * cos);

                    CHECK_NAN(pack[i], "rope.out");
                    CHECK_NAN(pack[i + 1], "rope.out");
                }
#endif

                // mean pool
                for (int i = 0; i < PACK_SIZE; i++) {
                    reduce_tmp[i] += pack[i];
                }
            });

            if (!pool_out) {
                return;
            }

            store<true>(&pool[warpId], reduce_tmp);
            __syncthreads();

            if (warpId < NUM_POOLS_PER_BLOCK) {
                const int row = warpId * NUM_WARPS_PER_POOL;
                reduce_tmp = load<true>(&pool[row]);
                for (int i = 1; i < NUM_WARPS_PER_POOL; i++) {
                    pack_t pack = load<true>(&pool[row + i]);
                    for (int j = 0; j < PACK_SIZE; j++) {
                        reduce_tmp[j] += pack[j];
                    }
                }
                for (int j = 0; j < PACK_SIZE; j++) {
                    reduce_tmp[j] /= PoolSize;
                }
                store(reinterpret_cast<pack_t *>(pool_out + warpId * N), reduce_tmp);
            }
            __syncthreads();
        }

        __device__ __forceinline__
        void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
            const int bm = binfo.bm;
            const int bn = binfo.bn;

            assert(binfo.numBlocksN % 3 == 0);
            const bool is_q = bn < binfo.numBlocksN / 3;
            const bool is_k = !is_q && bn < binfo.numBlocksN / 3 * 2;

            assert(!args.pool_out || args.actualM == M);
            assert(args.actualN == N);

            if (is_q || is_k) {
                apply(fpsum,
                      args.out + bm * BLOCK_M * args.actualN + bn * BLOCK_N,
                      M, N, K,
                      args.pool_out ? args.pool_out + bm * BLOCK_M / PoolSize * N : nullptr,
                      args.rotary_emb + bm * BLOCK_M * (HEAD_DIM / 2 * ROTARY_EMB_NUM_ELEMENTS),
                      is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
                      args.epsilon,
                      args.actualM - bm * BLOCK_M);
            } else {
                EpilogueDefault()(binfo, fpsum, M, N, K, typename EpilogueDefault::Arguments{
                    .out = args.out,
                    .actualM = args.actualM,
                    .actualN = args.actualN,
                });
            }
        }
    };
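    // Sketch, not part of this commit: the rotary_emb table consumed above is a row-major float
    // buffer of shape [M, HEAD_DIM / 2, 2] holding one {sin, cos} pair per (row, adjacent channel
    // pair). A host-side fill for the usual theta = pos * base^(-2i/d) parameterization could look
    // like the helper below; the base value and helper name are assumptions, not taken from the diff.
    static void fill_rotary_emb_sketch(float *rotary_emb, int M, float theta_base = 10000.0f) {
        constexpr int HALF_DIM = EpilogueQKVProj::HEAD_DIM / 2;
        for (int pos = 0; pos < M; pos++) {
            for (int i = 0; i < HALF_DIM; i++) {
                float theta = pos * powf(theta_base, -2.0f * i / EpilogueQKVProj::HEAD_DIM);
                rotary_emb[(pos * HALF_DIM + i) * 2 + 0] = sinf(theta);  // read back as rope[2*i]
                rotary_emb[(pos * HALF_DIM + i) * 2 + 1] = cosf(theta);  // read back as rope[2*i + 1]
            }
        }
    }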
    struct EpilogueRMSNormRope {
        static constexpr int HEAD_DIM = 128;
        static constexpr int NUM_HEADS_PER_WARP = WARP_N / HEAD_DIM;
        static constexpr int WARP_N_TILES_PER_HEAD = WARP_N_TILES / NUM_HEADS_PER_WARP;

        static constexpr int ROTARY_EMB_NUM_ELEMENTS = 2;

        using packed_rotemb_t = float4;
        static constexpr int WARP_N_ROTEMB_TILES = WARP_N_TILES / NUM_HEADS_PER_WARP * 2;
        using rotemb_warp = std::array<packed_rotemb_t, WARP_M_TILES * WARP_N_ROTEMB_TILES>;  // 128 regs

        struct Arguments {
            // **packed** [M, HEAD_DIM] float => [M // 16, HEAD_DIM // 8, WARP_SIZE] of packed_rotemb_t
            // aka [M // BLOCK_M, NUM_WARPS, WARP_M_TILES, WARP_N_TILES // NUM_HEADS_PER_WARP * 2, WARP_SIZE]
            const packed_rotemb_t *rotary_emb;
            const half_t *rmsnorm_weight_q;  // [HEAD_DIM]
            const half_t *rmsnorm_weight_k;  // [HEAD_DIM]
            float epsilon;
        };

        __device__ __forceinline__
        static rotemb_warp load_rotemb(const packed_rotemb_t *ptr_rotemb) {
            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;

            rotemb_warp rotemb;
            const packed_rotemb_t *ptrlane = &ptr_rotemb[warpId * WARP_M_TILES * WARP_N_ROTEMB_TILES * WARP_SIZE + laneId];
            unrolled_loop<WARP_M_TILES>([&]<int i>() {
                unrolled_loop<WARP_N_ROTEMB_TILES>([&]<int j>() {
                    constexpr int offset = (i * WARP_N_ROTEMB_TILES + j) * WARP_SIZE;
                    rotemb[i * WARP_N_ROTEMB_TILES + j] = load(&ptrlane[offset]);
                });
            });
            return rotemb;
        }

        __device__ __forceinline__
        static void load_rmsnorm(const half_t *ptr_rmsnorm_weight, half_t *shmem) {
            const int laneId = threadIdx.x % WARP_SIZE;

            static constexpr int PACK_SIZE = HEAD_DIM / WARP_SIZE;
            using packed_t = std::array<half_t, PACK_SIZE>;

            packed_t pack = load(reinterpret_cast<const packed_t *>(ptr_rmsnorm_weight + laneId * PACK_SIZE));
            store<true>(reinterpret_cast<packed_t *>(shmem + laneId * PACK_SIZE), pack);
        }

        __device__ __forceinline__
        static packed_fpsum_t load_rmsnorm_from_shmem(half_t *shmem, int n) {
            const int laneId = threadIdx.x % WARP_SIZE;
            const int col = n * INSN_N + laneId / 16 * 8;  // lane 0-15: n*16+0, lane 16-31: n*16+8
            uint4 tmp;
            ldmatrix(shmem + col, tmp);
            return kernels::bit_cast<packed_fpsum_t>(tmp);
        }

        __device__ __forceinline__
        static void apply(fpsum_warp &fpsum, const packed_rotemb_t *ptr_rotemb, const half_t *ptr_rmsnorm_weight, float epsilon) {
            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;

            __shared__ half_t shmem_rmsnorm[NUM_WARPS][HEAD_DIM];

            load_rmsnorm(ptr_rmsnorm_weight, &shmem_rmsnorm[warpId][0]);
            __syncwarp();

            rotemb_warp rotemb = load_rotemb(ptr_rotemb);

            float rmsnorm_coef[NUM_HEADS_PER_WARP][WARP_M_TILES][2];

            auto sqr = [](half2_t val) ALWAYSINLINE {
                float2 fval = half22float2(val);
                return fval.x * fval.x + fval.y * fval.y;
            };

#pragma unroll
            for (int head = 0; head < NUM_HEADS_PER_WARP; head++) {
                const int n_offset = head * WARP_N_TILES_PER_HEAD;
#pragma unroll
                for (int m = 0; m < WARP_M_TILES; m++) {
                    float sqrsum[2] = { 0.0f, 0.0f };
#pragma unroll
                    for (int n = 0; n < WARP_N_TILES_PER_HEAD; n++) {
                        sqrsum[0] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[0]);
                        sqrsum[1] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[1]);
                        sqrsum[0] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[2]);
                        sqrsum[1] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[3]);
                    }
#pragma unroll
                    for (int mask = 1; mask <= 2; mask *= 2) {
                        sqrsum[0] += __shfl_xor_sync(~0, sqrsum[0], mask);
                        sqrsum[1] += __shfl_xor_sync(~0, sqrsum[1], mask);
                    }
                    rmsnorm_coef[head][m][0] = cuda_frsqrt(sqrsum[0] / HEAD_DIM + epsilon);
                    rmsnorm_coef[head][m][1] = cuda_frsqrt(sqrsum[1] / HEAD_DIM + epsilon);
                }
            }

#pragma unroll
            for (int head = 0; head < NUM_HEADS_PER_WARP; head++) {
                const int n_offset = head * WARP_N_TILES_PER_HEAD;
#pragma unroll
                for (int n = 0; n < WARP_N_TILES_PER_HEAD; n++) {
                    packed_f32psum_t rms = packed_fp16_to_fp32(load_rmsnorm_from_shmem(&shmem_rmsnorm[warpId][0], n));
#pragma unroll
                    for (int m = 0; m < WARP_M_TILES; m++) {
                        packed_f32psum_t pack = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n + n_offset]);

                        pack.data[0] *= rmsnorm_coef[head][m][0] * rms.data[0];
                        pack.data[1] *= rmsnorm_coef[head][m][0] * rms.data[1];
                        pack.data[2] *= rmsnorm_coef[head][m][1] * rms.data[2];
                        pack.data[3] *= rmsnorm_coef[head][m][1] * rms.data[3];
                        pack.data[4] *= rmsnorm_coef[head][m][0] * rms.data[4];
                        pack.data[5] *= rmsnorm_coef[head][m][0] * rms.data[5];
                        pack.data[6] *= rmsnorm_coef[head][m][1] * rms.data[6];
                        pack.data[7] *= rmsnorm_coef[head][m][1] * rms.data[7];

                        auto rope = [](float &x, float &y, float sin, float cos) ALWAYSINLINE {
                            float ix = x, iy = y;
                            x = ix * cos - iy * sin;
                            y = ix * sin + iy * cos;
                        };

                        {
                            packed_rotemb_t sincos = rotemb[m * WARP_N_ROTEMB_TILES + n * 2];
                            rope(pack.data[0], pack.data[1], sincos.x, sincos.y);
                            rope(pack.data[2], pack.data[3], sincos.z, sincos.w);
                        }
                        {
                            packed_rotemb_t sincos = rotemb[m * WARP_N_ROTEMB_TILES + n * 2 + 1];
                            rope(pack.data[4], pack.data[5], sincos.x, sincos.y);
                            rope(pack.data[6], pack.data[7], sincos.z, sincos.w);
                        }

                        fpsum[m * WARP_N_TILES + n + n_offset] = packed_fp32_to_fp16(pack);
                    }
                }
            }
        }

        __device__ __forceinline__
        void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
            const int bm = binfo.bm;
            const int bn = binfo.bn;

            assert(binfo.numBlocksN % 3 == 0);
            const bool is_q = bn < binfo.numBlocksN / 3;
            const bool is_k = !is_q && bn < binfo.numBlocksN / 3 * 2;

            if (is_q || is_k) {
                apply(fpsum,
                      args.rotary_emb + bm * NUM_WARPS * WARP_M_TILES * WARP_N_ROTEMB_TILES * WARP_SIZE,
                      is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
                      args.epsilon);
            }
        }
    };
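    // Sketch, not part of this commit: a scalar reference for what EpilogueRMSNormRope does to one
    // 128-wide head, handy when validating the warp-tiled code above (hypothetical names; sin_t and
    // cos_t hold one value per adjacent channel pair, as in the packed rotary table).
    static void rmsnorm_rope_reference_sketch(const float *x, const float *rms_w,
                                              const float *sin_t, const float *cos_t,
                                              float eps, float *y) {
        constexpr int HD = EpilogueRMSNormRope::HEAD_DIM;
        float sqrsum = 0.0f;
        for (int d = 0; d < HD; d++) sqrsum += x[d] * x[d];
        const float coef = 1.0f / sqrtf(sqrsum / HD + eps);
        for (int d = 0; d < HD; d += 2) {
            float a = x[d] * coef * rms_w[d];
            float b = x[d + 1] * coef * rms_w[d + 1];
            y[d]     = a * cos_t[d / 2] - b * sin_t[d / 2];
            y[d + 1] = a * sin_t[d / 2] + b * cos_t[d / 2];
        }
    }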
    struct EpiloguePackQKV {
        using attn_half_t = half;
        using attn_half2_t = half2;
        using packed_qkv_t = uint4;

        static constexpr int HEAD_DIM = 128;
        static constexpr int INSN_K_QK = 16;
        static constexpr int INSN_K_PV = 16;

        struct Arguments {
            packed_qkv_t *out_q, *out_k, *out_v;
            int actualM;
            // !!! stride in number of packed_qkv_t !!!
            int strideHead_q;
            int strideHead_k;
            int strideHead_v;
        };

        __device__ __forceinline__
        static attn_half2_t convert_half2(half2_t input) {
            if constexpr (std::is_same_v<half2_t, attn_half2_t>) {
                return input;
            } else {
                float2 fval = half22float2(input);
                return float22half2<attn_half2_t>(fval);
            }
        }

        __device__ __forceinline__
        static packed_qkv_t pack_q(packed_fpsum_t input) {
            packed_qkv_t output;
            output.x = kernels::bit_cast<int>(convert_half2(input.data[0]));
            output.y = kernels::bit_cast<int>(convert_half2(input.data[1]));
            output.z = kernels::bit_cast<int>(convert_half2(input.data[2]));
            output.w = kernels::bit_cast<int>(convert_half2(input.data[3]));
            return output;
        }

        __device__ __forceinline__
        static packed_qkv_t pack_k(packed_fpsum_t input) {
            packed_qkv_t output;
            output.x = kernels::bit_cast<int>(convert_half2(input.data[0]));
            output.y = kernels::bit_cast<int>(convert_half2(input.data[2]));
            output.z = kernels::bit_cast<int>(convert_half2(input.data[1]));
            output.w = kernels::bit_cast<int>(convert_half2(input.data[3]));
            return output;
        }

        __device__ __forceinline__
        static packed_qkv_t pack_v(packed_fpsum_t input) {
            packed_qkv_t output;
            output.x = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[0])));
            output.y = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[1])));
            output.z = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[2])));
            output.w = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[3])));
            return output;
        }

        __device__ __forceinline__
        static void mask(packed_qkv_t &pack, uint32_t maskVal, int m, int maxRows) {
            const int laneId = threadIdx.x % WARP_SIZE;
            if (m * INSN_M + laneId / 4 >= maxRows) {
                pack.x = maskVal;
                pack.z = maskVal;
            }
            if (m * INSN_M + laneId / 4 + 8 >= maxRows) {
                pack.y = maskVal;
                pack.w = maskVal;
            }
        }

        // qkv: [batch, head, bm, NUM_WARPS, WARP_M_TILES, WARP_N_TILES, WARP_SIZE] of packed_qkv_t
        template<typename F>
        __device__ __forceinline__
        static void apply(fpsum_warp &fpsum, packed_qkv_t *ptr_output, int maxRows, F &&funcPack, attn_half2_t maskVal) {
            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;

            static_assert(HEAD_DIM == WARP_N);

            packed_qkv_t *ptrlane = &ptr_output[((warpId * WARP_M_TILES + 0) * WARP_N_TILES + 0) * WARP_SIZE + laneId];

            unrolled_loop<WARP_M_TILES>([&]<int m>() ALWAYSINLINE {
                unrolled_loop<WARP_N_TILES>([&]<int n>() ALWAYSINLINE {
                    packed_qkv_t pack = funcPack(fpsum[m * WARP_N_TILES + n]);
                    mask(pack, kernels::bit_cast<uint32_t>(maskVal), m, maxRows - warpId * WARP_M);
                    store(&ptrlane[(m * WARP_N_TILES + n) * WARP_SIZE], pack);
                });
            });
        }

        __device__ __forceinline__
        void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
            const int bm = binfo.bm;
            const int bn = binfo.bn;

            assert(binfo.numBlocksN % 3 == 0);
            const int numBlocksQ = binfo.numBlocksN / 3;
            const bool is_q = bn < numBlocksQ;
            const bool is_k = !is_q && bn < numBlocksQ * 2;

            // bn is head_id (assume HEAD_DIM == WARP_N)
            int head_id, strideHead;
            if (is_q) {
                head_id = bn;
                strideHead = args.strideHead_q;
            } else if (is_k) {
                head_id = bn - numBlocksQ;
                strideHead = args.strideHead_k;
            } else {
                head_id = bn - numBlocksQ * 2;
                strideHead = args.strideHead_v;
            }

            int block_offset = head_id * strideHead + bm * NUM_WARPS * WARP_M_TILES * WARP_N_TILES * WARP_SIZE;
            int maxRows = args.actualM - bm * BLOCK_M;

            // static constexpr float neginf = -std::numeric_limits<float>::infinity();

            if (is_q) {
                apply(fpsum, args.out_q + block_offset, maxRows, pack_q, attn_half2_t(0.0f, 0.0f));
            } else if (is_k) {
                apply(fpsum, args.out_k + block_offset, maxRows, pack_k, attn_half2_t(NAN, NAN));
            } else {
                apply(fpsum, args.out_v + block_offset, maxRows, pack_v, attn_half2_t(0.0f, 0.0f));
            }
        }
    };
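    // Sketch, not part of this commit: with the output layout documented above
    // ([batch, head, bm, NUM_WARPS, WARP_M_TILES, WARP_N_TILES, WARP_SIZE] of packed_qkv_t),
    // the per-head strides in EpiloguePackQKV::Arguments would presumably be computed on the host
    // as the number of packed_qkv_t covering every row block of one head (hypothetical helper):
    static constexpr int pack_qkv_stride_head_sketch(int numBlocksM) {
        return numBlocksM * NUM_WARPS * WARP_M_TILES * WARP_N_TILES * WARP_SIZE;
    }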
    struct EpilogueLiteLA {
        __device__ __forceinline__
        static packed_f32psum_t mma_litela(packed_fpsum_t k, packed_fpsum_t v, packed_f32psum_t psum) {
            for (int i = 0; i < 4; i++) {
                k.data[i] = movmatrix(k.data[i]);
                v.data[i] = movmatrix(v.data[i]);
            }
            std::swap(v.data[1], v.data[2]);
            return mma_f16xf16_f32(v, k, psum);
        }

        static constexpr int LITELA_HEAD_DIM = 32;
        static constexpr int LITELA_K_TILES = LITELA_HEAD_DIM / 16;
        static constexpr int LITELA_V_TILES = LITELA_HEAD_DIM / 16;

        static constexpr int SHMEM_SIZE = NUM_WARPS * (LITELA_HEAD_DIM + 1) * (LITELA_HEAD_DIM + 8) * sizeof(float);

        // out_vk: [batch_size, num_heads, head_dim + 1, head_dim]
        __device__ __forceinline__
        static void apply_litela(const BlockInfo binfo, fpsum_warp fpsum, float *out_vk, int num_blocks_per_batch) {
            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;

            using vk_t = float[NUM_WARPS][LITELA_HEAD_DIM + 1][LITELA_HEAD_DIM + 8];
            extern __shared__ uint8_t shmem[];
            vk_t &shmem_vk = *reinterpret_cast<vk_t *>(shmem);

            static_assert(sizeof(vk_t) == SHMEM_SIZE);
            static_assert(WARP_N == BLOCK_N);
            assert(binfo.numBlocksN % 3 == 0);

            const int num_heads = binfo.numBlocksN / 3 * 2 * (WARP_N / (LITELA_HEAD_DIM * 2));
            const int batch_id = binfo.bm / num_blocks_per_batch;

            for (int head_id = 0; head_id < WARP_N / (LITELA_HEAD_DIM * 2); head_id++) {
                const int global_head_id = (binfo.bn - binfo.numBlocksN / 3) * (WARP_N / (LITELA_HEAD_DIM * 2)) + head_id;
                float *out_vk_current_head = out_vk + (batch_id * num_heads + global_head_id) * (LITELA_HEAD_DIM + 1) * LITELA_HEAD_DIM;

                for (int i = laneId; i < sizeof(shmem_vk) / sizeof(float) / NUM_WARPS; i += WARP_SIZE) {
                    *((&shmem_vk[warpId][0][0]) + i) = 0;
                }
                __syncwarp();

                for (int tile_v = 0; tile_v < LITELA_V_TILES; tile_v++) {
                    for (int tile_k = 0; tile_k < LITELA_K_TILES; tile_k++) {
                        packed_f32psum_t attn_sum = { 0 };
                        for (int i = 0; i < WARP_M_TILES; i++) {
                            packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
                            packed_fpsum_t v = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + LITELA_HEAD_DIM / 16 + tile_v];
                            for (int j = 0; j < 4; j++) {
                                k.data[j] = __hmax2(k.data[j], half2_t(0, 0));  // relu
                            }
                            attn_sum = mma_litela(k, v, attn_sum);
                        }

                        const int row = tile_v * 16 + laneId / 4;
                        const int col = tile_k * 16 + laneId % 4 * 2;

                        shmem_vk[warpId][row + 0][col + 0] = attn_sum.data[0];
                        shmem_vk[warpId][row + 0][col + 1] = attn_sum.data[1];
                        shmem_vk[warpId][row + 8][col + 0] = attn_sum.data[2];
                        shmem_vk[warpId][row + 8][col + 1] = attn_sum.data[3];
                        shmem_vk[warpId][row + 0][col + 8] = attn_sum.data[4];
                        shmem_vk[warpId][row + 0][col + 9] = attn_sum.data[5];
                        shmem_vk[warpId][row + 8][col + 8] = attn_sum.data[6];
                        shmem_vk[warpId][row + 8][col + 9] = attn_sum.data[7];
                    }
                }
                for (int tile_k = 0; tile_k < LITELA_K_TILES; tile_k++) {
                    packed_f32psum_t attn_sum = { 0 };
                    for (int i = 0; i < WARP_M_TILES; i++) {
                        packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
                        packed_fpsum_t v = {};
                        for (int j = 0; j < 4; j++) {
                            k.data[j] = __hmax2(k.data[j], half2_t(0, 0));  // relu
                        }
#pragma unroll
                        for (int i = 0; i < 4; i++) {
                            v.data[i] = half2_t(1, 1);
                        }
                        // if (laneId < 4) {
                        //     v.data[0] = half2_t(1, 1);
                        //     v.data[2] = half2_t(1, 1);
                        // }
                        // if (laneId % 4 == 0) {
                        //     v.data[0] = half2_t(1, 0);
                        //     v.data[1] = half2_t(1, 0);
                        // }
                        attn_sum = mma_litela(k, v, attn_sum);
                    }
                    const int row = LITELA_HEAD_DIM + laneId / 4;
                    const int col = tile_k * 16 + laneId % 4 * 2;
                    if (laneId < 4) {
                        shmem_vk[warpId][row + 0][col + 0] = attn_sum.data[0];
                        shmem_vk[warpId][row + 0][col + 1] = attn_sum.data[1];
                        shmem_vk[warpId][row + 0][col + 8] = attn_sum.data[4];
                        shmem_vk[warpId][row + 0][col + 9] = attn_sum.data[5];
                    }
                }
                __syncthreads();
                for (int i = warpId; i < LITELA_HEAD_DIM + 1; i += NUM_WARPS) {
                    for (int j = laneId; j < LITELA_HEAD_DIM; j += WARP_SIZE) {
                        float sum = 0;
                        for (int k = 0; k < NUM_WARPS; k++) {
                            sum += shmem_vk[k][i][j];
                        }
                        reduce_add(&out_vk_current_head[i * LITELA_HEAD_DIM + j], sum);
                    }
                }
                __syncthreads();
            }
        }

        struct Arguments {
            half_t *out_q;
            float *out_vk;
            int num_blocks_per_batch;
            int actualM;
        };

        __device__ __forceinline__
        void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
            const int bm = binfo.bm;
            const int bn = binfo.bn;

            if (bn < binfo.numBlocksN / 3) {
                fpsum = apply_act(fpsum, [](half_t x) { return __hmax(x, 0); });  // relu
                return EpilogueDefault()(binfo, fpsum, M, N / 3, K, typename EpilogueDefault::Arguments{
                    .out = args.out_q,
                    .actualM = args.actualM,
                    .actualN = N / 3,
                });
            }

            return apply_litela(binfo, fpsum, args.out_vk, args.num_blocks_per_batch);
        }

        // each thread block mults BlockSize*HEAD_DIM q and (HEAD_DIM+1)*HEAD_DIM vk, in-place writes back to q
        // q: [batch_size, #blocks, block_size, #heads, HEAD_DIM]
        // vk: [batch_size, #heads, HEAD_DIM+1, HEAD_DIM]
        struct vk_mul_q_kernel {
            static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;  // FIXME FIXME FIXME

            __device__
            void operator()(half_t *q, const float *vk, float eps, int num_tokens) {
                const int block_id = blockIdx.x;
                const int head_id = blockIdx.y;
                const int batch_id = blockIdx.z;
                const int num_blocks = gridDim.x;
                const int num_heads = gridDim.y;
                const int block_size = blockDim.x;

                bool pred = block_id * block_size + threadIdx.x < num_tokens;

                half_t *localq = &q[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads + head_id) * LITELA_HEAD_DIM];
                const float *localvk = &vk[(batch_id * num_heads + head_id) * (LITELA_HEAD_DIM + 1) * LITELA_HEAD_DIM];
                // half_t *localout = &out[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads + head_id) * LITELA_HEAD_DIM];

                using packed_q = std::array<half_t, 8>;
                using packed_vk = std::array<float, 4>;

                half_t qblock[LITELA_HEAD_DIM];
                for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_q) / sizeof(half_t)) {
                    if (pred) {
                        *reinterpret_cast<packed_q *>(&qblock[i]) = load(reinterpret_cast<const packed_q *>(&localq[i]));
                    }
                }

                float outblock[LITELA_HEAD_DIM + 1];
#pragma unroll
                for (int j = 0; j < LITELA_HEAD_DIM + 1; j++) {
                    outblock[j] = 0;
#pragma unroll
                    for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_vk) / sizeof(float)) {
                        packed_vk vkpack = load(reinterpret_cast<const packed_vk *>(&localvk[j * LITELA_HEAD_DIM + i]));
#pragma unroll
                        for (int k = 0; k < vkpack.size(); k++) {
                            outblock[j] += (float)qblock[i + k] * vkpack[k];
                        }
                    }
                }

                for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_q) / sizeof(half_t)) {
                    packed_q opack;
                    for (int k = 0; k < opack.size(); k++) {
                        opack[k] = __fdividef(outblock[i + k], outblock[LITELA_HEAD_DIM] + eps);
                    }
                    if (pred) {
                        store(reinterpret_cast<packed_q *>(&localq[i]), opack);
                    }
                }
            }
        };
    };
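    // Sketch, not part of this commit: a scalar reference for the two LiteLA stages above.
    // apply_litela accumulates, per head, vk[d][e] = sum_t v[t][d] * relu(k[t][e]) plus an extra
    // row vk[HEAD][e] = sum_t relu(k[t][e]); vk_mul_q_kernel then rescales q in place as
    //     out[t][d] = (sum_e q[t][e] * vk[d][e]) / (sum_e q[t][e] * vk[HEAD][e] + eps)
    // with q already passed through ReLU by the Q branch of operator(). Hypothetical helper name.
    static void litela_reference_sketch(const float *q, const float *vk, float eps,
                                        int num_tokens, float *out) {
        constexpr int D = EpilogueLiteLA::LITELA_HEAD_DIM;
        for (int t = 0; t < num_tokens; t++) {
            float denom = eps;
            for (int e = 0; e < D; e++) denom += q[t * D + e] * vk[D * D + e];
            for (int d = 0; d < D; d++) {
                float num = 0.0f;
                for (int e = 0; e < D; e++) num += q[t * D + e] * vk[d * D + e];
                out[t * D + d] = num / denom;
            }
        }
    }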
    template<typename Epilogue>
    struct test_epilogue_kernel {
        static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;

        static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(Base::template load_act_to_fpsum<false>::SHMEM_SIZE, 128) * 128;
        static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;

        struct Arguments {
            const half_t *input;
            half_t *output;

            // aligned to BLOCK_M and BLOCK_N
            int M, N;
            int actualM, actualN;

            typename Epilogue::Arguments argsEpilogue;
        };

        __device__ __forceinline__
        void operator()(Arguments args) {
            const BlockInfo binfo = {
                .bm = (int)blockIdx.x,
                .bn = (int)blockIdx.y,
                .numBlocksM = (int)gridDim.x,
                .numBlocksN = (int)gridDim.y,
            };

            const int bm = binfo.bm;
            const int bn = binfo.bn;
            const int warpId = threadIdx.x / WARP_SIZE;

            const int m_offset = bm * BLOCK_M + warpId * WARP_M;
            const int n_offset = bn * BLOCK_N;

            extern __shared__ uint8_t shmem[];

            fpsum_warp fpsum;

            Base::template load_act_to_fpsum<false>()(
                args.input + m_offset * args.actualN + n_offset,
                args.actualN,
                args.actualM - m_offset,
                args.actualN - n_offset,
                fpsum,
                shmem + warpId * SHMEM_PER_WARP);

            Epilogue()(binfo, fpsum, args.M, args.N, 0, args.argsEpilogue);

            EpilogueDefault()(binfo, fpsum, args.M, args.N, 0, typename EpilogueDefault::Arguments{
                .out = args.output,
                .actualM = args.actualM,
                .actualN = args.actualN,
            });
        }
    };
};

};  // namespace nunchaku::kernels
\ No newline at end of file
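The test wrapper above is presumably what gemm_w4a4_test.cu (also touched by this commit, +3 −2) instantiates. Under the conventions visible in this file, a launch would look roughly like the sketch below; invoke_kernel is a hypothetical entry point, and the tiling constants are assumed to be reachable from the GEMM config — neither is taken from the diff.

    using E      = Epilogues<GEMMConfig_W4A4_FP16>;
    using Kernel = E::test_epilogue_kernel<E::EpilogueGelu>;

    dim3 grid(M / BLOCK_M, N / BLOCK_N);   // one block per BLOCK_M x BLOCK_N output tile
    dim3 block(NUM_WARPS * WARP_SIZE);
    invoke_kernel<Kernel><<<grid, block, Kernel::SHMEM_SIZE>>>(Kernel::Arguments{...});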
src/kernels/zgemm/gemm_base.cuh
...
@@ -256,6 +256,16 @@ public:
        return results;
    }

    __device__ __forceinline__
    static f32psum_warp packed_fp16_to_fp32(fpsum_warp input) {
        f32psum_warp results;
#pragma unroll
        for (int i = 0; i < results.size(); i++) {
            results[i] = packed_fp16_to_fp32(input[i]);
        }
        return results;
    }

    // activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t
    __device__ __forceinline__
    static void load_act(const packed_act_t *act, int k, int K, act_warp &out, bool pred) {
...
@@ -570,6 +580,63 @@ public:
        }
    };
    // loads act of [WARP_M, WARP_N] and stores to fpsum_warp
    // [WARP_M, WARP_N * 2] when fuse_glu
    template<bool fuse_glu>
    struct load_act_to_fpsum {
        using matrix_t = half_t[INSN_M][WARP_N + 8];
        static constexpr size_t SHMEM_SIZE = sizeof(matrix_t);

        __device__ __forceinline__
        void operator()(const half_t *input, int stride, int maxRows, int maxCols, fpsum_warp &out, void *shmem) {
            const int laneId = threadIdx.x % WARP_SIZE;

            matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);

            constexpr int PACK_SIZE = WARP_N / WARP_SIZE;
            using packed_input = std::array<half_t, PACK_SIZE>;
            using packed_raw_input = std::array<half2_t, PACK_SIZE>;

#pragma unroll
            for (int m = 0; m < WARP_M_TILES; m++) {
#pragma unroll
                for (int row = 0; row < INSN_M; row++) {
                    packed_input pack;
                    // TODO: numCols not multiples of PACK_SIZE
                    if constexpr (fuse_glu) {
                        packed_raw_input raw;
                        raw.fill(half2_t(0, 0));

                        bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE * 2 < maxCols;
                        if (pred) {
                            raw = load(reinterpret_cast<const packed_raw_input *>(input + (m * INSN_M + row) * stride + laneId * PACK_SIZE * 2));
                        }
#pragma unroll
                        for (int j = 0; j < PACK_SIZE; j++) {
                            pack[j] = raw[j].x * silu(raw[j].y);
                        }
                    } else {
                        pack.fill(half_t(0));

                        bool pred = (m * INSN_M + row) < maxRows && laneId * PACK_SIZE < maxCols;
                        if (pred) {
                            pack = load(reinterpret_cast<const packed_input *>(input + (m * INSN_M + row) * stride + laneId * PACK_SIZE));
                        }
                    }
                    store<true>(reinterpret_cast<packed_input *>(&mat[row][laneId * PACK_SIZE]), pack);
                }
                __syncwarp();
                for (int n = 0; n < WARP_N_TILES; n++) {
                    const int row = laneId % 16;
                    const int col = n * INSN_N + laneId / 16 * 8;
                    uint4 tmp;
                    ldmatrix(&mat[row][col], tmp);
                    *reinterpret_cast<uint4 *>(&out[m * WARP_N_TILES + n]) = tmp;
                }
                __syncwarp();
            }
        }
    };
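    // Note, not part of this commit: in the fuse_glu path above the loader reads interleaved
    // (value, gate) column pairs and writes value * silu(gate), i.e. for a row of width 2*W the
    // produced row of width W is, in scalar form:
    //     for (int j = 0; j < W; j++) out[j] = in[2 * j] * silu(in[2 * j + 1]);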
    template<typename F>
    __device__ __forceinline__
...
@@ -599,7 +666,7 @@ public:
    };

    __device__ __forceinline__
-   void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
+   void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
        const int warpId = threadIdx.x / WARP_SIZE;

        __shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];
...
@@ -632,7 +699,7 @@
        struct Arguments { size_t unused; };

        __device__ __forceinline__
-       void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
+       void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
        }
    };
...
@@ -696,7 +763,7 @@
        }

        __device__ __forceinline__
-       void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
+       void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
            const int bn = binfo.bn;
            if constexpr (USE_BIAS || USE_SCALE) {
                apply_bias(
...
@@ -712,7 +779,7 @@
        struct Arguments { size_t unused; };

        __device__ __forceinline__
-       void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
+       void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
            fpsum = apply_act(fpsum, [](half_t x) { return silu(x); });
        }
    };
...
@@ -722,7 +789,7 @@
        using Arguments = std::tuple<typename Epilogues::Arguments ...>;

        __device__ __forceinline__
-       void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
+       void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
            // this function makes intellisense crashes :(
#if __INTELLISENSE__
            __trap();  // should not happen when actually compiling
...
src/kernels/zgemm/gemm_utils.cuh
...
@@ -358,6 +358,14 @@ static void reduce_add(float *addr, float val) {
    asm volatile ("red.relaxed.gpu.global.add.f32 [%0], %1;" :: "l"(addr), "f"(val));
}

__device__ __forceinline__
static void reduce_add_pred(float *addr, float val, bool pred) {
    asm volatile ("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
                  "@storepred red.relaxed.gpu.global.add.f32 [%1], %2;"
                  "}" :: "r"((int)pred), "l"(addr), "f"(val));
}
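// Note, not part of this commit: reduce_add_pred behaves like the branchy form
//     if (pred) reduce_add(addr, val);
// but keeps the fire-and-forget atomic add predicated instead of introducing divergent control flow.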
template<int cnt, typename F>
__device__ __forceinline__
static void unrolled_loop(F &&lambda) {
...
src/kernels/zgemm/gemm_w4a4.cuh

#pragma once

#include "gemm_base.cuh"
#include "lora.cuh"
// #include "gemm_w4a4_block.cuh"

namespace nunchaku::kernels {
...
@@ -256,7 +257,7 @@ public:
        const packed_wmscale_t *wscales,
        float alpha,    // per-tensor scale of weight
        int M, int N, int K,
-       Epilogue::Arguments epilogueArgs,
+       const Epilogue::Arguments &epilogueArgs,
        bool alwaysfalse)
    {
        constexpr int NUM_STAGES = 2;
...
@@ -500,64 +501,6 @@
        }
    }
    // [deleted here: the load_act_to_fpsum<fuse_glu> struct, identical to the copy added to
    //  gemm_base.cuh above]
    /**
     * each warp quantizes a INSN_M * INSN_K (16 * 64) matrix
...
@@ -883,7 +826,7 @@ public:
        // const packed_wscale_t *bias_ptr,
        // half_t *out,
        int M, int N, int K,
-       Epilogue::Arguments epilogueArgs,
+       const Epilogue::Arguments &epilogueArgs,
        bool alwaysfalse)
    {
        constexpr int NUM_STAGES = 2;
...
@@ -1057,7 +1000,7 @@
    }

    __device__ __forceinline__
-   void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
+   void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, const Arguments &args) {
        const int bm = binfo.bm;
        const int bn = binfo.bn;
...
@@ -1077,1045 +1020,6 @@
    };
    // using EpilogueQuantizeFuseGelu = EpilogueQuantize<true>;
    template<int rank = 32>
    struct Lora {
        static_assert(rank % 16 == 0);

        static constexpr int LORA_RANK = rank;

        static constexpr int LORA_M_TILES = WARP_M / 16;
        static constexpr int LORA_R_TILES = LORA_RANK / 16;
        static constexpr int LORA_N_TILES = WARP_N / 16;

        static_assert(LORA_M_TILES == WARP_M_TILES);
        static_assert(LORA_N_TILES == WARP_N_TILES);

        // lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
        // lora up:   [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N]
        // we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
        using lora_act_warp = std::array<packed_f32psum_t, LORA_M_TILES * LORA_R_TILES>;
        using lora_act16_warp = std::array<packed_fpsum_t, LORA_M_TILES * LORA_R_TILES>;
        using lora_wgt_warp = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>;

        using scale_t = std::array<float, LORA_R_TILES>;

        // lora_wgt: [N / 16, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
        __device__ __forceinline__
        static lora_wgt_warp load_lora_wgt(const packed_fpsum_t *ptr) {
            const int laneId = threadIdx.x % WARP_SIZE;
            const packed_fpsum_t *ptr_lane = ptr + laneId;

            lora_wgt_warp result;
#if 0
#pragma unroll
            for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
                for (int r = 0; r < LORA_R_TILES; r++) {
                    result[n * LORA_R_TILES + r] = load(ptr_lane + (n * LORA_R_TILES + r) * WARP_SIZE);
                }
            }
#else
            unrolled_loop<LORA_N_TILES>([&]<int n>() {
                unrolled_loop<LORA_R_TILES>([&]<int r>() {
                    constexpr int offset = (n * LORA_R_TILES + r) * WARP_SIZE;
                    result[n * LORA_R_TILES + r] = load(ptr_lane + offset);
                });
            });
#endif
            return result;
        }

        // lora_act: [M / BLOCK_M, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
        __device__ __forceinline__
        static lora_act16_warp load_lora_act(const float *ptr, scale_t scales) {
            const int laneId = threadIdx.x % WARP_SIZE;
            const float *ptrlane = ptr + laneId;

            lora_act16_warp result;
#if 0
#pragma unroll
            for (int i = 0; i < LORA_M_TILES * LORA_R_TILES; i++) {
                packed_f32psum_t tmp;
#pragma unroll
                for (int j = 0; j < 8; j++) {
                    const int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
                    tmp.data[j] = ptrlane[offset];
                    // tmp.data[j] = ptr[i * 8 * WARP_SIZE + j * WARP_SIZE + laneId];
                }
                CHECK_NAN(tmp, "load_lora_act.tmp");
                result[i] = packed_fp32_to_fp16(tmp);
            }
#else
            unrolled_loop<LORA_M_TILES>([&]<int m>() {
                unrolled_loop<LORA_R_TILES>([&]<int r> {
                    constexpr int i = m * LORA_R_TILES + r;
                    packed_f32psum_t tmp;
                    unrolled_loop<8>([&]<int j>() {
                        constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
                        tmp.data[j] = ptrlane[offset] * scales[r];
                    });
                    CHECK_NAN(tmp, "load_lora_act.tmp");
                    result[i] = packed_fp32_to_fp16(tmp);
                });
            });
#endif
            return result;
        }

        // no vector reduction in sm_89 :(
        __device__ __forceinline__
        static void reduce_lora_act(float *ptr, lora_act_warp val) {
            const int laneId = threadIdx.x % WARP_SIZE;
            float *ptrlane = ptr + laneId;

            // #pragma unroll
            // for (int i = 0; i < LORA_M_TILES * LORA_R_TILES; i++) {
            //     #pragma unroll
            //     for (int j = 0; j < 8; j++) {
            //         int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
            //         reduce_add(&ptrlane[offset], val[i].data[j]);
            //     }
            // }
            unrolled_loop<LORA_M_TILES * LORA_R_TILES>([&]<int i>() {
                unrolled_loop<8>([&]<int j>() {
                    constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
                    reduce_add(&ptrlane[offset], val[i].data[j]);
                });
            });
        }

        // __device__ __forceinline__
        // static void reduce_lora_act(float *ptr, lora_act_warp val, int m) {
        //     const int laneId = threadIdx.x % WARP_SIZE;
        //     float *ptrlane = ptr + laneId + m * LORA_R_TILES * 8 * WARP_SIZE;
        //     unrolled_loop<LORA_R_TILES>([&]<int r>() {
        //         unrolled_loop<8>([&]<int j>() {
        //             constexpr int offset = r * 8 * WARP_SIZE + j * WARP_SIZE;
        //             reduce_add(&ptrlane[offset], val[m * LORA_R_TILES + r].data[j]);
        //         });
        //     });
        // }

        struct EpilogueLoraUp {
            struct Arguments {
                const float *lora_act;
                const packed_fpsum_t *lora_wgt_up;
                scale_t scales;
            };

            __device__ __forceinline__
            static void apply_lora_up(fpsum_warp &fpsum, int M, int N, int K, const float *act,
                                      const packed_fpsum_t *wgt, const scale_t scales, const BlockInfo binfo) {
                const int laneId = threadIdx.x % WARP_SIZE;
                const int warpId = threadIdx.x / WARP_SIZE;

                if constexpr (rank > 0) {
                    lora_act16_warp lora_act = load_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), scales);
                    lora_wgt_warp lora_wgt = load_lora_wgt(wgt);

                    for (int m = 0; m < LORA_M_TILES; m++) {
                        for (int n = 0; n < LORA_N_TILES; n++) {
                            packed_f32psum_t psum = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n]);
                            for (int r = 0; r < LORA_R_TILES; r++) {
                                CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act");
                                CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt");
                                psum = mma_f16xf16_f32(lora_act[m * LORA_R_TILES + r], lora_wgt[n * LORA_R_TILES + r], psum);
                            }
                            fpsum[m * WARP_N_TILES + n] = packed_fp32_to_fp16(psum);
                        }
                    }
                }
            }

            __device__ __forceinline__
            void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
                const int bm = binfo.bm;
                const int bn = binfo.bn;

                CHECK_NAN(fpsum, "fpsum");

                if constexpr (rank == 0) {
                    return;
                }

                apply_lora_up(fpsum, M, N, K,
                              args.lora_act + bm * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
                              args.lora_wgt_up + bn * (BLOCK_N / 16) * LORA_R_TILES * WARP_SIZE,
                              args.scales,
                              binfo  // for debug
                );

                CHECK_NAN(fpsum, "fpsum");
            }
        };

        struct EpilogueLoraDown {
            struct Arguments {
                const packed_fpsum_t *lora_wgt_down;
                float *lora_act;
            };

            __device__ __forceinline__
            static void apply_lora_down(fpsum_warp &fpsum, int M, int N, int K, float *act, const packed_fpsum_t *wgt) {
                const int laneId = threadIdx.x % WARP_SIZE;
                const int warpId = threadIdx.x / WARP_SIZE;

                if constexpr (rank > 0) {
                    lora_act_warp lora_act;
                    lora_act.fill(packed_f32psum_t::zeros());

                    lora_wgt_warp lora_wgt = load_lora_wgt(wgt);

                    // clock_t dummy = 0;

#pragma unroll
                    for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll
                        for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
                            for (int r = 0; r < LORA_R_TILES; r++) {
                                auto &psum = lora_act[m * LORA_R_TILES + r];
                                CHECK_NAN(fpsum[m * WARP_N_TILES + n], "apply_lora_down.fpsum");
                                CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "apply_lora_down.lora_wgt");
                                psum = mma_f16xf16_f32(fpsum[m * WARP_N_TILES + n], lora_wgt[n * LORA_R_TILES + r], psum);
                                CHECK_NAN(psum, "apply_lora_down.psum");
                            }
                        }
                        // reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act, m);
                        // if (alwaysfalse) {
                        //     dummy = clock();
                        // }
                    }
                    reduce_lora_act(act + warpId * (LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE), lora_act);

                    // unused_var(dummy, alwaysfalse);
                }
            }

            __device__ __forceinline__
            void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
                const int bm = binfo.bm;
                const int bn = binfo.bn;

                if constexpr (rank == 0) {
                    return;
                }

                apply_lora_down(fpsum, M, N, K,
                                args.lora_act + bm * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
                                args.lora_wgt_down + bn * (BLOCK_N / 16) * LORA_R_TILES * WARP_SIZE);
            }
        };

        template<bool fuse_glu, bool use_fp4>
        struct quantize_w4a4_fuse_lora_kernel {
            using oscales_t = typename std::conditional_t<use_fp4, packed_amscale_t, packed_ascale_t>;

            static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;

            static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<fuse_glu>::SHMEM_SIZE, 128) * 128;
            static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;

            struct Arguments {
                const half_t *input;
                const packed_wscale_t *smooth_factor;
                packed_act_t *output;
                oscales_t *oscales;

                const packed_fpsum_t *lora_wgt_down;
                float *lora_act;

                // aligned to BLOCK_M and BLOCK_N
                int M, N;  // N should be the actual K in the next GEMM (needs /2 if fuse_glu)
                // the actual M and N (no need to /2 if fuse_glu)
                int actualM, actualN;
            };

            __device__ __forceinline__
            void operator()(Arguments args) {
                const BlockInfo binfo = {
                    .bm = (int)blockIdx.x,
                    .bn = (int)blockIdx.y,
                    .numBlocksM = (int)gridDim.x,
                    .numBlocksN = (int)gridDim.y,
                };

                const int bm = binfo.bm;
                const int bn = binfo.bn;
                const int warpId = threadIdx.x / WARP_SIZE;

                const int m_offset = bm * BLOCK_M + warpId * WARP_M;
                const int n_offset = bn * BLOCK_N * (fuse_glu ? 2 : 1);

                extern __shared__ uint8_t shmem[];

                fpsum_warp fpsum;

                load_act_to_fpsum<fuse_glu>()(
                    args.input + m_offset * args.actualN + n_offset,
                    args.actualN,
                    args.actualM - m_offset,
                    args.actualN - n_offset,
                    fpsum,
                    shmem + warpId * SHMEM_PER_WARP
                    // args.smooth_factor ? args.smooth_factor + n_offset : nullptr
                );

                CHECK_NAN(fpsum, "fpsum");

                // for (int i = 0; i < 16; i++) {
                //     printf("bm=%d bn=%d warp=%d lane=%d fpsum[%d][0:1]=%f %f\n",
                //         bm, bn, warpId, threadIdx.x % WARP_SIZE, i,
                //         (float)fpsum[i].data[0].x, (float)fpsum[i].data[0].y);
                // }

                EpilogueLoraDown()(binfo, fpsum, args.M, args.N, 0, typename EpilogueLoraDown::Arguments{
                    .lora_wgt_down = args.lora_wgt_down,
                    .lora_act = args.lora_act,
                });

                EpilogueQuantize<false, false, use_fp4>()(binfo, fpsum, args.M, args.N, 0,
                    typename EpilogueQuantize<false, false, use_fp4>::Arguments{
                        .qout = args.output,
                        .oscales = args.oscales,
                        .shift_value = 0,
                        .smooth_factor = args.smooth_factor
                    });
            }
        };
    };
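    // Note, not part of this commit (and the updated launch code is not shown on this page): a
    // rank-templated struct like Lora<rank> is typically dispatched from the host by switching over
    // the supported multiples of 16, e.g.
    //     switch (lora_rank) {
    //         case 0:  run<Lora<0>>(...);  break;
    //         case 16: run<Lora<16>>(...); break;
    //         case 32: run<Lora<32>>(...); break;
    //         // ... otherwise pad the LoRA weights up to the next multiple of 16
    //     }
    // and the per-16-rank scales (scale_t) let a single instantiation apply a different LoRA
    // strength to each 16-wide rank block.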
struct
EpilogueGelu
{
struct
Arguments
{
size_t
unused
;
};
// static constexpr float SHIFT_VALUE = 0.171875f;
__device__
__forceinline__
void
operator
()(
const
BlockInfo
binfo
,
fpsum_warp
&
fpsum
,
int
M
,
int
N
,
int
K
,
Arguments
args
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
WARP_M_TILES
;
i
++
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
WARP_N_TILES
;
j
++
)
{
#pragma unroll
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
half2_t
&
data
=
fpsum
[
i
*
WARP_N_TILES
+
j
].
data
[
k
];
data
=
gelu_half2
(
data
);
// data = __hadd2(data, half2_t(SHIFT_VALUE, SHIFT_VALUE));
}
}
}
}
};
// template<int PoolSize = 128>
struct
EpilogueQKVProj
{
struct
Arguments
{
half_t
*
out
;
int
actualM
,
actualN
;
half_t
*
pool_out
;
// [M / PoolSize, N]
const
float
*
rotary_emb
;
// [M, HEAD_DIM / 2, ROTARY_EMB_NUM_ELEMENTS]
const
half_t
*
rmsnorm_weight_q
;
// [HEAD_DIM]
const
half_t
*
rmsnorm_weight_k
;
// [HEAD_DIM]
float
epsilon
;
};
static
constexpr
int
HEAD_DIM
=
128
;
static
constexpr
int
NUM_HEADS_PER_WARP
=
WARP_N
/
HEAD_DIM
;
static
constexpr
int
PoolSize
=
128
;
static
constexpr
int
NUM_WARPS_PER_POOL
=
PoolSize
/
WARP_M
;
static
constexpr
int
NUM_POOLS_PER_BLOCK
=
BLOCK_M
/
PoolSize
;
static
constexpr
int
ROTARY_EMB_NUM_ELEMENTS
=
2
;
// 1 for theta, 2 for {sin, cos} pair
__device__
__forceinline__
static
void
apply
(
fpsum_warp
fpsum
,
half_t
*
out
,
int
M
,
int
N
,
int
K
,
half_t
*
pool_out
,
const
float
*
rotary_emb
,
const
half_t
*
rmsnorm_weight
,
float
epsilon
,
int
maxRows
)
{
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
__shared__
alignas
(
128
)
uint8_t
shmem
[
NUM_WARPS
][
ceilDiv
(
unpack_fpsum
::
SHMEM_SIZE
,
128
)
*
128
];
constexpr
int
PACK_SIZE
=
unpack_fpsum
::
PACK_SIZE
;
using
pack_t
=
unpack_fpsum
::
pack_t
;
using
pack_rope_t
=
std
::
array
<
float
,
PACK_SIZE
/
2
*
ROTARY_EMB_NUM_ELEMENTS
>
;
constexpr
int
LANES_PER_HEAD
=
HEAD_DIM
/
PACK_SIZE
;
pack_t
reduce_tmp
;
__shared__
alignas
(
128
)
pack_t
pool
[
NUM_WARPS
];
// load rmsnorm scales
pack_t
rms
;
if
(
laneId
<
LANES_PER_HEAD
)
{
rms
=
load
(
reinterpret_cast
<
const
pack_t
*>
(
&
rmsnorm_weight
[
laneId
*
PACK_SIZE
]));
}
if
constexpr
(
LANES_PER_HEAD
<
WARP_SIZE
)
{
for
(
int
i
=
0
;
i
<
PACK_SIZE
;
i
++
)
{
rms
[
i
]
=
__shfl_sync
(
~
0
,
rms
[
i
],
laneId
%
LANES_PER_HEAD
);
}
}
const
float
*
rotary_emb_base_addr
=
&
rotary_emb
[(
warpId
*
WARP_M
)
*
HEAD_DIM
/
2
*
ROTARY_EMB_NUM_ELEMENTS
+
laneId
*
PACK_SIZE
/
2
*
ROTARY_EMB_NUM_ELEMENTS
];
CHECK_NAN
(
fpsum
,
"fpsum"
);
unpack_fpsum
()(
fpsum
,
out
+
warpId
*
WARP_M
*
N
,
N
,
maxRows
-
warpId
*
WARP_M
,
INT_MAX
,
shmem
[
warpId
],
[
&
](
int
rowId
,
pack_t
&
pack
)
ALWAYSINLINE
{
// load rope
pack_rope_t
rope
;
if
(
laneId
<
LANES_PER_HEAD
)
{
// freq = load(reinterpret_cast<pack_freq_t *>(&freqs_cis[(warpId * WARP_M + rowId) * HEAD_DIM * 2 + laneId * PACK_SIZE * 2]));
rope
=
load
(
reinterpret_cast
<
const
pack_rope_t
*>
(
&
rotary_emb_base_addr
[
rowId
*
HEAD_DIM
/
2
*
ROTARY_EMB_NUM_ELEMENTS
]));
}
if
constexpr
(
LANES_PER_HEAD
<
WARP_SIZE
)
{
for
(
int
i
=
0
;
i
<
rope
.
size
();
i
++
)
{
rope
[
i
]
=
__shfl_sync
(
~
0
,
rope
[
i
],
laneId
%
LANES_PER_HEAD
);
}
}
// rmsnorm
float
sqrsum
=
0.0
f
;
for
(
int
i
=
0
;
i
<
PACK_SIZE
;
i
++
)
{
sqrsum
+=
float
(
pack
[
i
])
*
float
(
pack
[
i
]);
CHECK_NAN
(
sqrsum
,
"sqrsum"
);
}
#pragma unroll
for
(
int
mask
=
LANES_PER_HEAD
/
2
;
mask
>
0
;
mask
/=
2
)
{
sqrsum
+=
__shfl_xor_sync
(
~
0
,
sqrsum
,
mask
);
}
sqrsum
/=
HEAD_DIM
;
float
coef
=
cuda_frsqrt
(
sqrsum
+
epsilon
);
CHECK_NAN
(
coef
,
"coef"
);
for
(
int
i
=
0
;
i
<
PACK_SIZE
;
i
++
)
{
pack
[
i
]
*=
coef
*
float
(
rms
[
i
]);
CHECK_NAN
(
rms
[
i
],
"rms.wgt"
);
CHECK_NAN
(
pack
[
i
],
"rms.out"
);
}
#if 1
// rope
for
(
int
i
=
0
;
i
<
PACK_SIZE
;
i
+=
2
)
{
float2
pack2
=
half22float2
(
half2_t
(
pack
[
i
],
pack
[
i
+
1
]));
CHECK_NAN
(
freq
[
i
].
x
,
"rope.freq"
);
CHECK_NAN
(
freq
[
i
].
y
,
"rope.freq"
);
CHECK_NAN
(
freq
[
i
+
1
].
x
,
"rope.freq"
);
CHECK_NAN
(
freq
[
i
+
1
].
y
,
"rope.freq"
);
// half2_t tmp = __hmul2(freq[i], pack2);
// tmp = __hfma2(freq[i+1], pack2, tmp);
// pack[i] = tmp.x;
// pack[i+1] = tmp.y;
// printf("block.x=%d block.y=%d warpId=%d rowId=%d (%d) freqs = %f %f %f %f\n",
// blockIdx.x, blockIdx.y, warpId, rowId,
// blockIdx.x * BLOCK_M + warpId * WARP_M + rowId,
// (float)freq[i].x, (float)freq[i].y, (float)freq[i+1].x, (float)freq[i+1].y
// );
// __trap();
// half2_t tmp = __hmul2(half2_t(pack2.x, pack2.x), freq[i]);
// tmp = __hfma2(half2_t(pack2.y, pack2.y), freq[i+1], tmp);
// pack[i] = tmp.x;
// pack[i+1] = tmp.y;
float
sin
,
cos
;
if
constexpr
(
ROTARY_EMB_NUM_ELEMENTS
==
1
)
{
sin
=
cuda_sin
(
rope
[
i
/
2
]);
cos
=
cuda_cos
(
rope
[
i
/
2
]);
}
if
constexpr
(
ROTARY_EMB_NUM_ELEMENTS
==
2
)
{
sin
=
rope
[
i
];
cos
=
rope
[
i
+
1
];
}
// pack[i] = pack2.x * freq[i].x + pack2.y * freq[i].y;
// pack[i+1] = pack2.x * freq[i+1].x + pack2.y * freq[i+1].y;
pack
[
i
]
=
half_t
(
pack2
.
x
*
cos
-
pack2
.
y
*
sin
);
pack
[
i
+
1
]
=
half_t
(
pack2
.
x
*
sin
+
pack2
.
y
*
cos
);
CHECK_NAN
(
pack
[
i
],
"rope.out"
);
CHECK_NAN
(
pack
[
i
+
1
],
"rope.out"
);
}
#endif
// mean pool
for
(
int
i
=
0
;
i
<
PACK_SIZE
;
i
++
)
{
reduce_tmp
[
i
]
+=
pack
[
i
];
}
});
if
(
!
pool_out
)
{
return
;
}
store
<
true
>
(
&
pool
[
warpId
],
reduce_tmp
);
__syncthreads
();
if
(
warpId
<
NUM_POOLS_PER_BLOCK
)
{
const
int
row
=
warpId
*
NUM_WARPS_PER_POOL
;
reduce_tmp
=
load
<
true
>
(
&
pool
[
row
]);
for
(
int
i
=
1
;
i
<
NUM_WARPS_PER_POOL
;
i
++
)
{
pack_t
pack
=
load
<
true
>
(
&
pool
[
row
+
i
]);
for
(
int
j
=
0
;
j
<
PACK_SIZE
;
j
++
)
{
reduce_tmp
[
j
]
+=
pack
[
j
];
}
}
for
(
int
j
=
0
;
j
<
PACK_SIZE
;
j
++
)
{
reduce_tmp
[
j
]
/=
PoolSize
;
}
store
(
reinterpret_cast
<
pack_t
*>
(
pool_out
+
warpId
*
N
),
reduce_tmp
);
}
__syncthreads
();
}
__device__
__forceinline__
void
operator
()(
const
BlockInfo
binfo
,
fpsum_warp
fpsum
,
int
M
,
int
N
,
int
K
,
Arguments
args
)
{
const
int
bm
=
binfo
.
bm
;
const
int
bn
=
binfo
.
bn
;
assert
(
binfo
.
numBlocksN
%
3
==
0
);
const
bool
is_q
=
bn
<
binfo
.
numBlocksN
/
3
;
const
bool
is_k
=
!
is_q
&&
bn
<
binfo
.
numBlocksN
/
3
*
2
;
assert
(
!
args
.
pool_out
||
args
.
actualM
==
M
);
assert
(
args
.
actualN
==
N
);
if
(
is_q
||
is_k
)
{
apply
(
fpsum
,
args
.
out
+
bm
*
BLOCK_M
*
args
.
actualN
+
bn
*
BLOCK_N
,
M
,
N
,
K
,
args
.
pool_out
?
args
.
pool_out
+
bm
*
BLOCK_M
/
PoolSize
*
N
:
nullptr
,
args
.
rotary_emb
+
bm
*
BLOCK_M
*
(
HEAD_DIM
/
2
*
ROTARY_EMB_NUM_ELEMENTS
),
is_q
?
args
.
rmsnorm_weight_q
:
args
.
rmsnorm_weight_k
,
args
.
epsilon
,
args
.
actualM
-
bm
*
BLOCK_M
);
}
else
{
EpilogueDefault
()(
binfo
,
fpsum
,
M
,
N
,
K
,
typename
EpilogueDefault
::
Arguments
{
.
out
=
args
.
out
,
.
actualM
=
args
.
actualM
,
.
actualN
=
args
.
actualN
,
});
}
}
};
struct
EpilogueRMSNormRope
{
static
constexpr
int
HEAD_DIM
=
128
;
static
constexpr
int
NUM_HEADS_PER_WARP
=
WARP_N
/
HEAD_DIM
;
static
constexpr
int
WARP_N_TILES_PER_HEAD
=
WARP_N_TILES
/
NUM_HEADS_PER_WARP
;
static
constexpr
int
ROTARY_EMB_NUM_ELEMENTS
=
2
;
using
packed_rotemb_t
=
float4
;
static
constexpr
int
WARP_N_ROTEMB_TILES
=
WARP_N_TILES
/
NUM_HEADS_PER_WARP
*
2
;
    using rotemb_warp = std::array<packed_rotemb_t, WARP_M_TILES * WARP_N_ROTEMB_TILES>;  // 128 regs

    struct Arguments {
        // **packed** [M, HEAD_DIM] float => [M // 16, HEAD_DIM // 8, WARP_SIZE] of packed_rotemb_t
        // aka [M // BLOCK_M, NUM_WARPS, WARP_M_TILES, WARP_N_TILES // NUM_HEADS_PER_WARP * 2, WARP_SIZE]
        const packed_rotemb_t *rotary_emb;
        const half_t *rmsnorm_weight_q;  // [HEAD_DIM]
        const half_t *rmsnorm_weight_k;  // [HEAD_DIM]
        float epsilon;
    };

    __device__ __forceinline__
    static rotemb_warp load_rotemb(const packed_rotemb_t *ptr_rotemb) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        rotemb_warp rotemb;
        const packed_rotemb_t *ptrlane = &ptr_rotemb[warpId * WARP_M_TILES * WARP_N_ROTEMB_TILES * WARP_SIZE + laneId];
        unrolled_loop<WARP_M_TILES>([&]<int i>() {
            unrolled_loop<WARP_N_ROTEMB_TILES>([&]<int j>() {
                constexpr int offset = (i * WARP_N_ROTEMB_TILES + j) * WARP_SIZE;
                rotemb[i * WARP_N_ROTEMB_TILES + j] = load(&ptrlane[offset]);
            });
        });
        return rotemb;
    }

    __device__ __forceinline__
    static void load_rmsnorm(const half_t *ptr_rmsnorm_weight, half_t *shmem) {
        const int laneId = threadIdx.x % WARP_SIZE;

        static constexpr int PACK_SIZE = HEAD_DIM / WARP_SIZE;
        using packed_t = std::array<half_t, PACK_SIZE>;

        packed_t pack = load(reinterpret_cast<const packed_t *>(ptr_rmsnorm_weight + laneId * PACK_SIZE));
        store<true>(reinterpret_cast<packed_t *>(shmem + laneId * PACK_SIZE), pack);
    }

    __device__ __forceinline__
    static packed_fpsum_t load_rmsnorm_from_shmem(half_t *shmem, int n) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int col = n * INSN_N + laneId / 16 * 8;  // lane 0-15: n*16+0, lane 16-31: n*16+8
        uint4 tmp;
        ldmatrix(shmem + col, tmp);
        return kernels::bit_cast<packed_fpsum_t>(tmp);
    }

    __device__ __forceinline__
    static void apply(fpsum_warp &fpsum, const packed_rotemb_t *ptr_rotemb, const half_t *ptr_rmsnorm_weight, float epsilon) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        __shared__ half_t shmem_rmsnorm[NUM_WARPS][HEAD_DIM];

        load_rmsnorm(ptr_rmsnorm_weight, &shmem_rmsnorm[warpId][0]);
        __syncwarp();

        rotemb_warp rotemb = load_rotemb(ptr_rotemb);

        float rmsnorm_coef[NUM_HEADS_PER_WARP][WARP_M_TILES][2];

        auto sqr = [](half2_t val) ALWAYSINLINE {
            float2 fval = half22float2(val);
            return fval.x * fval.x + fval.y * fval.y;
        };

#pragma unroll
        for (int head = 0; head < NUM_HEADS_PER_WARP; head++) {
            const int n_offset = head * WARP_N_TILES_PER_HEAD;
#pragma unroll
            for (int m = 0; m < WARP_M_TILES; m++) {
                float sqrsum[2] = {0.0f, 0.0f};
#pragma unroll
                for (int n = 0; n < WARP_N_TILES_PER_HEAD; n++) {
                    sqrsum[0] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[0]);
                    sqrsum[1] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[1]);
                    sqrsum[0] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[2]);
                    sqrsum[1] += sqr(fpsum[m * WARP_N_TILES + n + n_offset].data[3]);
                }
#pragma unroll
                for (int mask = 1; mask <= 2; mask *= 2) {
                    sqrsum[0] += __shfl_xor_sync(~0, sqrsum[0], mask);
                    sqrsum[1] += __shfl_xor_sync(~0, sqrsum[1], mask);
                }
                rmsnorm_coef[head][m][0] = cuda_frsqrt(sqrsum[0] / HEAD_DIM + epsilon);
                rmsnorm_coef[head][m][1] = cuda_frsqrt(sqrsum[1] / HEAD_DIM + epsilon);
            }
        }

#pragma unroll
        for (int head = 0; head < NUM_HEADS_PER_WARP; head++) {
            const int n_offset = head * WARP_N_TILES_PER_HEAD;
#pragma unroll
            for (int n = 0; n < WARP_N_TILES_PER_HEAD; n++) {
                packed_f32psum_t rms = packed_fp16_to_fp32(load_rmsnorm_from_shmem(&shmem_rmsnorm[warpId][0], n));
#pragma unroll
                for (int m = 0; m < WARP_M_TILES; m++) {
                    packed_f32psum_t pack = packed_fp16_to_fp32(fpsum[m * WARP_N_TILES + n + n_offset]);
                    pack.data[0] *= rmsnorm_coef[head][m][0] * rms.data[0];
                    pack.data[1] *= rmsnorm_coef[head][m][0] * rms.data[1];
                    pack.data[2] *= rmsnorm_coef[head][m][1] * rms.data[2];
                    pack.data[3] *= rmsnorm_coef[head][m][1] * rms.data[3];
                    pack.data[4] *= rmsnorm_coef[head][m][0] * rms.data[4];
                    pack.data[5] *= rmsnorm_coef[head][m][0] * rms.data[5];
                    pack.data[6] *= rmsnorm_coef[head][m][1] * rms.data[6];
                    pack.data[7] *= rmsnorm_coef[head][m][1] * rms.data[7];

                    auto rope = [](float &x, float &y, float sin, float cos) ALWAYSINLINE {
                        float ix = x, iy = y;
                        x = ix * cos - iy * sin;
                        y = ix * sin + iy * cos;
                    };
                    {
                        packed_rotemb_t sincos = rotemb[m * WARP_N_ROTEMB_TILES + n * 2];
                        rope(pack.data[0], pack.data[1], sincos.x, sincos.y);
                        rope(pack.data[2], pack.data[3], sincos.z, sincos.w);
                    }
                    {
                        packed_rotemb_t sincos = rotemb[m * WARP_N_ROTEMB_TILES + n * 2 + 1];
                        rope(pack.data[4], pack.data[5], sincos.x, sincos.y);
                        rope(pack.data[6], pack.data[7], sincos.z, sincos.w);
                    }

                    fpsum[m * WARP_N_TILES + n + n_offset] = packed_fp32_to_fp16(pack);
                }
            }
        }
    }

    __device__ __forceinline__
    void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
        const int bm = binfo.bm;
        const int bn = binfo.bn;

        assert(binfo.numBlocksN % 3 == 0);
        const bool is_q = bn < binfo.numBlocksN / 3;
        const bool is_k = !is_q && bn < binfo.numBlocksN / 3 * 2;

        if (is_q || is_k) {
            apply(
                fpsum,
                args.rotary_emb + bm * NUM_WARPS * WARP_M_TILES * WARP_N_ROTEMB_TILES * WARP_SIZE,
                is_q ? args.rmsnorm_weight_q : args.rmsnorm_weight_k,
                args.epsilon);
        }
    }
};
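A minimal host-side sketch (not part of this commit) of the per-head math this fused epilogue performs on each Q/K row: RMS-normalize the HEAD_DIM slice, scale by the rmsnorm weight, then rotate consecutive (even, odd) pairs by the rotary embedding. The separate sin/cos vectors here stand in for the packed sin/cos layout of packed_rotemb_t, which is assumed to be precomputed elsewhere.

#include <cmath>
#include <vector>

void rmsnorm_rope_reference(std::vector<float> &x,              // [HEAD_DIM] one head of one token
                            const std::vector<float> &weight,   // [HEAD_DIM] rmsnorm weight
                            const std::vector<float> &sin_v,    // [HEAD_DIM / 2]
                            const std::vector<float> &cos_v,    // [HEAD_DIM / 2]
                            float epsilon) {
    const int head_dim = (int)x.size();
    float sqrsum = 0.f;
    for (float v : x) sqrsum += v * v;
    const float coef = 1.f / std::sqrt(sqrsum / head_dim + epsilon);   // cuda_frsqrt(...) in the kernel
    for (int i = 0; i < head_dim; i++) x[i] *= coef * weight[i];
    for (int i = 0; i < head_dim / 2; i++) {                           // rope() lambda in the kernel
        const float a = x[2 * i], b = x[2 * i + 1];
        x[2 * i]     = a * cos_v[i] - b * sin_v[i];
        x[2 * i + 1] = a * sin_v[i] + b * cos_v[i];
    }
}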
struct EpiloguePackQKV {
    using attn_half_t = half;
    using attn_half2_t = half2;
    using packed_qkv_t = uint4;

    static constexpr int HEAD_DIM = 128;
    static constexpr int INSN_K_QK = 16;
    static constexpr int INSN_K_PV = 16;

    struct Arguments {
        packed_qkv_t *out_q, *out_k, *out_v;
        int actualM;
        // !!! stride in number of packed_qkv_t !!!
        int strideHead_q;
        int strideHead_k;
        int strideHead_v;
    };

    __device__ __forceinline__
    static attn_half2_t convert_half2(half2_t input) {
        if constexpr (std::is_same_v<half2_t, attn_half2_t>) {
            return input;
        } else {
            float2 fval = half22float2(input);
            return float22half2<attn_half2_t>(fval);
        }
    }

    __device__ __forceinline__
    static packed_qkv_t pack_q(packed_fpsum_t input) {
        packed_qkv_t output;
        output.x = kernels::bit_cast<int>(convert_half2(input.data[0]));
        output.y = kernels::bit_cast<int>(convert_half2(input.data[1]));
        output.z = kernels::bit_cast<int>(convert_half2(input.data[2]));
        output.w = kernels::bit_cast<int>(convert_half2(input.data[3]));
        return output;
    }

    __device__ __forceinline__
    static packed_qkv_t pack_k(packed_fpsum_t input) {
        packed_qkv_t output;
        output.x = kernels::bit_cast<int>(convert_half2(input.data[0]));
        output.y = kernels::bit_cast<int>(convert_half2(input.data[2]));
        output.z = kernels::bit_cast<int>(convert_half2(input.data[1]));
        output.w = kernels::bit_cast<int>(convert_half2(input.data[3]));
        return output;
    }

    __device__ __forceinline__
    static packed_qkv_t pack_v(packed_fpsum_t input) {
        packed_qkv_t output;
        output.x = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[0])));
        output.y = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[1])));
        output.z = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[2])));
        output.w = kernels::bit_cast<int>(convert_half2(movmatrix(input.data[3])));
        return output;
    }

    __device__ __forceinline__
    static void mask(packed_qkv_t &pack, uint32_t maskVal, int m, int maxRows) {
        const int laneId = threadIdx.x % WARP_SIZE;
        if (m * INSN_M + laneId / 4 >= maxRows) {
            pack.x = maskVal;
            pack.z = maskVal;
        }
        if (m * INSN_M + laneId / 4 + 8 >= maxRows) {
            pack.y = maskVal;
            pack.w = maskVal;
        }
    }

    // qkv: [batch, head, bm, NUM_WARPS, WARP_M_TILES, WARP_N_TILES, WARP_SIZE] of packed_qkv_t
    template<typename F>
    __device__ __forceinline__
    static void apply(fpsum_warp &fpsum, packed_qkv_t *ptr_output, int maxRows, F &&funcPack, attn_half2_t maskVal) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        static_assert(HEAD_DIM == WARP_N);

        packed_qkv_t *ptrlane = &ptr_output[((warpId * WARP_M_TILES + 0) * WARP_N_TILES + 0) * WARP_SIZE + laneId];

        unrolled_loop<WARP_M_TILES>([&]<int m>() ALWAYSINLINE {
            unrolled_loop<WARP_N_TILES>([&]<int n>() ALWAYSINLINE {
                packed_qkv_t pack = funcPack(fpsum[m * WARP_N_TILES + n]);
                mask(pack, kernels::bit_cast<uint32_t>(maskVal), m, maxRows - warpId * WARP_M);
                store(&ptrlane[(m * WARP_N_TILES + n) * WARP_SIZE], pack);
            });
        });
    }

    __device__ __forceinline__
    void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
        const int bm = binfo.bm;
        const int bn = binfo.bn;

        assert(binfo.numBlocksN % 3 == 0);
        const int numBlocksQ = binfo.numBlocksN / 3;
        const bool is_q = bn < numBlocksQ;
        const bool is_k = !is_q && bn < numBlocksQ * 2;

        // bn is head_id (assume HEAD_DIM == WARP_N)
        int head_id, strideHead;
        if (is_q) {
            head_id = bn;
            strideHead = args.strideHead_q;
        } else if (is_k) {
            head_id = bn - numBlocksQ;
            strideHead = args.strideHead_k;
        } else {
            head_id = bn - numBlocksQ * 2;
            strideHead = args.strideHead_v;
        }

        int block_offset = head_id * strideHead + bm * NUM_WARPS * WARP_M_TILES * WARP_N_TILES * WARP_SIZE;
        int maxRows = args.actualM - bm * BLOCK_M;

        // static constexpr float neginf = -std::numeric_limits<float>::infinity();

        if (is_q) {
            apply(fpsum, args.out_q + block_offset, maxRows, pack_q, attn_half2_t(0.0f, 0.0f));
        } else if (is_k) {
            apply(fpsum, args.out_k + block_offset, maxRows, pack_k, attn_half2_t(NAN, NAN));
        } else {
            apply(fpsum, args.out_v + block_offset, maxRows, pack_v, attn_half2_t(0.0f, 0.0f));
        }
    }
};
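A minimal sketch (not part of this commit) of the accumulator layout that mask() relies on: in the m16n8 Tensor Core fragment, lane L of the warp owns rows L/4 and L/4 + 8 of the tile, so an out-of-range row can be checked per lane. This mirrors the PTX fragment-layout documentation; the exact register-to-element mapping beyond what the code above uses is an assumption.

struct FragRows {
    int row_lo, row_hi;
};

// data[0]/data[2] (pack.x/pack.z) sit on row_lo, data[1]/data[3] (pack.y/pack.w) sit on row_hi.
inline FragRows fragment_rows(int laneId, int m_tile, int insn_m /* assumed INSN_M == 16 */) {
    return { m_tile * insn_m + laneId / 4, m_tile * insn_m + laneId / 4 + 8 };
}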
struct EpilogueLiteLA {
    __device__ __forceinline__
    static packed_f32psum_t mma_litela(packed_fpsum_t k, packed_fpsum_t v, packed_f32psum_t psum) {
        for (int i = 0; i < 4; i++) {
            k.data[i] = movmatrix(k.data[i]);
            v.data[i] = movmatrix(v.data[i]);
        }
        std::swap(v.data[1], v.data[2]);
        return mma_f16xf16_f32(v, k, psum);
    }

    static constexpr int LITELA_HEAD_DIM = 32;
    static constexpr int LITELA_K_TILES = LITELA_HEAD_DIM / 16;
    static constexpr int LITELA_V_TILES = LITELA_HEAD_DIM / 16;

    static constexpr int SHMEM_SIZE = NUM_WARPS * (LITELA_HEAD_DIM + 1) * (LITELA_HEAD_DIM + 8) * sizeof(float);

    // out_vk: [batch_size, num_heads, head_dim + 1, head_dim]
    __device__ __forceinline__
    static void apply_litela(const BlockInfo binfo, fpsum_warp fpsum, float *out_vk, int num_blocks_per_batch) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        using vk_t = float[NUM_WARPS][LITELA_HEAD_DIM + 1][LITELA_HEAD_DIM + 8];
        extern __shared__ uint8_t shmem[];
        vk_t &shmem_vk = *reinterpret_cast<vk_t *>(shmem);

        static_assert(sizeof(vk_t) == SHMEM_SIZE);
        static_assert(WARP_N == BLOCK_N);
        assert(binfo.numBlocksN % 3 == 0);

        const int num_heads = binfo.numBlocksN / 3 * 2 * (WARP_N / (LITELA_HEAD_DIM * 2));
        const int batch_id = binfo.bm / num_blocks_per_batch;

        for (int head_id = 0; head_id < WARP_N / (LITELA_HEAD_DIM * 2); head_id++) {
            const int global_head_id = (binfo.bn - binfo.numBlocksN / 3) * (WARP_N / (LITELA_HEAD_DIM * 2)) + head_id;
            float *out_vk_current_head = out_vk + (batch_id * num_heads + global_head_id) * (LITELA_HEAD_DIM + 1) * LITELA_HEAD_DIM;

            for (int i = laneId; i < sizeof(shmem_vk) / sizeof(float) / NUM_WARPS; i += WARP_SIZE) {
                *((&shmem_vk[warpId][0][0]) + i) = 0;
            }
            __syncwarp();

            for (int tile_v = 0; tile_v < LITELA_V_TILES; tile_v++) {
                for (int tile_k = 0; tile_k < LITELA_K_TILES; tile_k++) {
                    packed_f32psum_t attn_sum = {0};
                    for (int i = 0; i < WARP_M_TILES; i++) {
                        packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
                        packed_fpsum_t v = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + LITELA_HEAD_DIM / 16 + tile_v];
                        for (int j = 0; j < 4; j++) {
                            k.data[j] = __hmax2(k.data[j], half2_t(0, 0));  // relu
                        }
                        attn_sum = mma_litela(k, v, attn_sum);
                    }

                    const int row = tile_v * 16 + laneId / 4;
                    const int col = tile_k * 16 + laneId % 4 * 2;

                    shmem_vk[warpId][row + 0][col + 0] = attn_sum.data[0];
                    shmem_vk[warpId][row + 0][col + 1] = attn_sum.data[1];
                    shmem_vk[warpId][row + 8][col + 0] = attn_sum.data[2];
                    shmem_vk[warpId][row + 8][col + 1] = attn_sum.data[3];
                    shmem_vk[warpId][row + 0][col + 8] = attn_sum.data[4];
                    shmem_vk[warpId][row + 0][col + 9] = attn_sum.data[5];
                    shmem_vk[warpId][row + 8][col + 8] = attn_sum.data[6];
                    shmem_vk[warpId][row + 8][col + 9] = attn_sum.data[7];
                }
            }

            for (int tile_k = 0; tile_k < LITELA_K_TILES; tile_k++) {
                packed_f32psum_t attn_sum = {0};
                for (int i = 0; i < WARP_M_TILES; i++) {
                    packed_fpsum_t k = fpsum[i * WARP_N_TILES + head_id * (LITELA_HEAD_DIM * 2) / 16 + tile_k];
                    packed_fpsum_t v = {};
                    for (int j = 0; j < 4; j++) {
                        k.data[j] = __hmax2(k.data[j], half2_t(0, 0));  // relu
                    }
#pragma unroll
                    for (int i = 0; i < 4; i++) {
                        v.data[i] = half2_t(1, 1);
                    }
                    // if (laneId < 4) {
                    //     v.data[0] = half2_t(1, 1);
                    //     v.data[2] = half2_t(1, 1);
                    // }
                    // if (laneId % 4 == 0) {
                    //     v.data[0] = half2_t(1, 0);
                    //     v.data[1] = half2_t(1, 0);
                    // }
                    attn_sum = mma_litela(k, v, attn_sum);
                }
                const int row = LITELA_HEAD_DIM + laneId / 4;
                const int col = tile_k * 16 + laneId % 4 * 2;
                if (laneId < 4) {
                    shmem_vk[warpId][row + 0][col + 0] = attn_sum.data[0];
                    shmem_vk[warpId][row + 0][col + 1] = attn_sum.data[1];
                    shmem_vk[warpId][row + 0][col + 8] = attn_sum.data[4];
                    shmem_vk[warpId][row + 0][col + 9] = attn_sum.data[5];
                }
            }
            __syncthreads();
            for (int i = warpId; i < LITELA_HEAD_DIM + 1; i += NUM_WARPS) {
                for (int j = laneId; j < LITELA_HEAD_DIM; j += WARP_SIZE) {
                    float sum = 0;
                    for (int k = 0; k < NUM_WARPS; k++) {
                        sum += shmem_vk[k][i][j];
                    }
                    reduce_add(&out_vk_current_head[i * LITELA_HEAD_DIM + j], sum);
                }
            }
            __syncthreads();
        }
    }

    struct Arguments {
        half_t *out_q;
        float *out_vk;
        int num_blocks_per_batch;
        int actualM;
    };

    __device__ __forceinline__
    void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
        const int bm = binfo.bm;
        const int bn = binfo.bn;
        if (bn < binfo.numBlocksN / 3) {
            fpsum = apply_act(fpsum, [](half_t x) { return __hmax(x, 0); });  // relu
            return EpilogueDefault()(binfo, fpsum, M, N / 3, K, typename EpilogueDefault::Arguments{
                .out = args.out_q,
                .actualM = args.actualM,
                .actualN = N / 3,
            });
        }
        return apply_litela(binfo, fpsum, args.out_vk, args.num_blocks_per_batch);
    }

    // each thread block mults BlockSize*HEAD_DIM q and (HEAD_DIM+1)*HEAD_DIM vk, in-place writes back to q
    // q:  [batch_size, #blocks, block_size, #heads, HEAD_DIM]
    // vk: [batch_size, #heads, HEAD_DIM+1, HEAD_DIM]
    struct vk_mul_q_kernel {
        static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;  // FIXME FIXME FIXME

        __device__
        void operator()(half_t *q, const float *vk, float eps, int num_tokens) {
            const int block_id = blockIdx.x;
            const int head_id = blockIdx.y;
            const int batch_id = blockIdx.z;
            const int num_blocks = gridDim.x;
            const int num_heads = gridDim.y;
            const int block_size = blockDim.x;

            bool pred = block_id * block_size + threadIdx.x < num_tokens;

            half_t *localq = &q[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads + head_id) * LITELA_HEAD_DIM];
            const float *localvk = &vk[(batch_id * num_heads + head_id) * (LITELA_HEAD_DIM + 1) * LITELA_HEAD_DIM];
            // half_t *localout = &out[(((batch_id * num_blocks + block_id) * block_size + threadIdx.x) * num_heads + head_id) * LITELA_HEAD_DIM];

            using packed_q = std::array<half_t, 8>;
            using packed_vk = std::array<float, 4>;

            half_t qblock[LITELA_HEAD_DIM];
            for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_q) / sizeof(half_t)) {
                if (pred) {
                    *reinterpret_cast<packed_q *>(&qblock[i]) = load(reinterpret_cast<const packed_q *>(&localq[i]));
                }
            }

            float outblock[LITELA_HEAD_DIM + 1];
#pragma unroll
            for (int j = 0; j < LITELA_HEAD_DIM + 1; j++) {
                outblock[j] = 0;
#pragma unroll
                for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_vk) / sizeof(float)) {
                    packed_vk vkpack = load(reinterpret_cast<const packed_vk *>(&localvk[j * LITELA_HEAD_DIM + i]));
#pragma unroll
                    for (int k = 0; k < vkpack.size(); k++) {
                        outblock[j] += (float)qblock[i + k] * vkpack[k];
                    }
                }
            }

            for (int i = 0; i < LITELA_HEAD_DIM; i += sizeof(packed_q) / sizeof(half_t)) {
                packed_q opack;
                for (int k = 0; k < opack.size(); k++) {
                    opack[k] = __fdividef(outblock[i + k], outblock[LITELA_HEAD_DIM] + eps);
                }
                if (pred) {
                    store(reinterpret_cast<packed_q *>(&localq[i]), opack);
                }
            }
        }
    };
};
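A minimal host-side sketch (not part of this commit) of what vk_mul_q_kernel computes for one token and head: a matrix-vector product against the accumulated vk, followed by normalization by the extra (HEAD_DIM+1)-th row, which apply_litela filled with the ReLU'd key sums. Array sizes follow LITELA_HEAD_DIM above.

#include <array>

constexpr int HEAD_DIM_SKETCH = 32;  // matches LITELA_HEAD_DIM

std::array<float, HEAD_DIM_SKETCH> vk_mul_q_reference(
        const std::array<float, HEAD_DIM_SKETCH> &q,
        const float (&vk)[HEAD_DIM_SKETCH + 1][HEAD_DIM_SKETCH],
        float eps) {
    float out[HEAD_DIM_SKETCH + 1];
    for (int j = 0; j < HEAD_DIM_SKETCH + 1; j++) {
        out[j] = 0.f;
        for (int i = 0; i < HEAD_DIM_SKETCH; i++) out[j] += q[i] * vk[j][i];
    }
    std::array<float, HEAD_DIM_SKETCH> result;
    for (int i = 0; i < HEAD_DIM_SKETCH; i++)
        result[i] = out[i] / (out[HEAD_DIM_SKETCH] + eps);  // normalize by the q·(sum of keys) row
    return result;
}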
template<typename Epilogue, bool ACT_UNSIGNED>
struct gemm_w4a4_kernel {
    static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
...
...
@@ -2167,21 +1071,31 @@ public:
    }
};

template<typename Epilogue>
struct test_epilogue_kernel {

template<bool fuse_glu, bool use_fp4>
struct quantize_w4a4_fuse_lora_kernel {
    using oscales_t = typename std::conditional_t<use_fp4, packed_amscale_t, packed_ascale_t>;

    static constexpr int MIN_ARCH = std::is_same_v<half_t, __nv_bfloat16> ? 800 : 750;
    static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<false>::SHMEM_SIZE, 128) * 128;
    static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(Base::template load_act_to_fpsum<fuse_glu>::SHMEM_SIZE, 128) * 128;
    static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;

    struct Arguments {
        const half_t *input;
        half_t *output;
        const packed_wscale_t *smooth_factor;
        packed_act_t *output;
        oscales_t *oscales;
        const packed_fpsum_t *lora_wgt_down;
        float *lora_act;
        int lora_rank;
        // aligned to BLOCK_M and BLOCK_N
        int M, N;
        int M, N;   // N should be the actual K in the next GEMM (needs /2 if fuse_glu)
        // the actual M and N (no need to /2 if fuse_glu)
        int actualM, actualN;
        typename Epilogue::Arguments argsEpilogue;
        bool alwaysfalse;
    };

    __device__ __forceinline__
...
...
@@ -2199,30 +1113,48 @@ public:
        const int warpId = threadIdx.x / WARP_SIZE;
        const int m_offset = bm * BLOCK_M + warpId * WARP_M;
        const int n_offset = bn * BLOCK_N;
        const int n_offset = bn * BLOCK_N * (fuse_glu ? 2 : 1);

        extern __shared__ uint8_t shmem[];

        fpsum_warp fpsum;

        load_act_to_fpsum<false>()(
        Base::template load_act_to_fpsum<fuse_glu>()(
            args.input + m_offset * args.actualN + n_offset,
            args.actualN,
            args.actualM - m_offset,
            args.actualN - n_offset,
            fpsum,
            shmem + warpId * SHMEM_PER_WARP
            // args.smooth_factor ? args.smooth_factor + n_offset : nullptr
        );

        Epilogue()(binfo, fpsum, args.M, args.N, 0, args.argsEpilogue);

        CHECK_NAN(fpsum, "fpsum");

        // for (int i = 0; i < 16; i++) {
        //     printf("bm=%d bn=%d warp=%d lane=%d fpsum[%d][0:1]=%f %f\n",
        //         bm, bn, warpId, threadIdx.x % WARP_SIZE, i,
        //         (float)fpsum[i].data[0].x, (float)fpsum[i].data[0].y);
        // }

        using EpilogueLoraDown = typename Lora<Config>::EpilogueLoraDown;
        EpilogueLoraDown()(binfo, fpsum, args.M, args.N, 0, typename EpilogueLoraDown::Arguments{
            .lora_wgt_down = args.lora_wgt_down,
            .lora_act = args.lora_act,
            .rank = args.lora_rank,
            .alwaysfalse = args.alwaysfalse,
        });

        EpilogueDefault()(binfo, fpsum, args.M, args.N, 0, typename EpilogueDefault::Arguments{
            .out = args.output,
            .actualM = args.actualM,
            .actualN = args.actualN,
        EpilogueQuantize<false, false, use_fp4>()(binfo, fpsum, args.M, args.N, 0, typename EpilogueQuantize<false, false, use_fp4>::Arguments{
            .qout = args.output,
            .oscales = args.oscales,
            .shift_value = 0,
            .smooth_factor = args.smooth_factor
        });
    }
};
};
};  // namespace nunchaku::kernels
\ No newline at end of file
src/kernels/zgemm/gemm_w4a4_launch.cuh

#include "gemm_w4a4.cuh"
#include "epilogues.cuh"

namespace nunchaku::kernels {

template<typename Config, bool USE_FP4>
class GEMM_W4A4_Launch {
    using GEMM = GEMM_W4A4<Config>;

    // using LoraRanks = std::integer_sequence<int, 0, 32>;
    // using LoraRanks = std::integer_sequence<int, 0, 32, 48, 64, 80, 96, 112, 128, 160, 176, 224>;
    using LoraRanks = std::integer_sequence<int, 0, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224>;
    // using LoraRanks = std::integer_sequence<int,
    //     0, 32, 48, 64, 80, 96, 112, 128, 144, 160,
    //     176, 192, 208, 224, 240, 256, 272, 288, 304, 320,
    //     336, 352, 368, 384, 400, 416, 432, 448, 464, 480,
    //     496, 512>;

    using Epilogues = Epilogues<Config>;
    using Lora = Lora<Config>;

    using packed_act_t = typename GEMM::packed_act_t;
    using packed_wgt_t = typename GEMM::packed_wgt_t;
...
...
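A minimal sketch (not part of this commit) of how a value-to-template dispatch over std::integer_sequence can work; the repository's dispatchVal helper is assumed to behave similarly, but its real signature and error handling may differ. The diff suggests that after this commit the rank is mostly passed as a runtime argument, so LoraRanks no longer has to enumerate every supported LoRA rank.

#include <stdexcept>
#include <utility>

template<typename F, int... Vals>
void dispatch_val_sketch(int value, std::integer_sequence<int, Vals...>, F &&func) {
    // Call func.template operator()<V>() for the first compile-time candidate V that matches value.
    bool matched = ((value == Vals ? (func.template operator()<Vals>(), true) : false) || ...);
    if (!matched) {
        throw std::runtime_error("unsupported value");  // e.g. a LoRA rank not in the list
    }
}

// usage sketch:
//   dispatch_val_sketch(rank, std::integer_sequence<int, 0, 32, 48, 64>{}, [&]<int RANK>() { /* instantiate kernel for RANK */ });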
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh
...
...
@@ -191,76 +191,82 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
    assert(lora_up.valid() == lora_act_in.valid());
    assert(lora_down.valid() == lora_act_out.valid());

    if (!lora_up.valid()) {
        assert(!lora_down.valid());
    const int rank_up = lora_up.valid() ? lora_up.shape[1] : 0;
    const int rank_down = lora_down.valid() ? lora_down.shape[1] : 0;
    if (rank_up == 0) {
        assert(rank_down == 0);
        return launch_bias.template operator()<typename GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>({midArgs, nextArgs});
    }

    const int rank_up = lora_up.shape[1];
    assert(rank_up % 16 == 0);
    assert(lora_up.shape[0] == N);
    // assert(lora_up.shape[1] == Lora::LORA_RANK);
    assert(lora_act_in.shape[0] == M);
    assert(lora_act_in.shape[1] == rank_up);

    dispatchVal(rank_up, LoraRanks(), [&]<int RANK_UP>() {
        using LoraUp = typename GEMM::Lora<RANK_UP>;
        using scale_t = typename LoraUp::scale_t;
    using LoraUp = Lora;
    using scale_t = typename LoraUp::scale_t;

        scale_t scales;
        if constexpr (scales.size() > 0) {
            assert(lora_scales.size() >= scales.size());
            for (size_t i = 0; i < scales.size(); i++) {
                scales[i] = lora_scales[i];
            }
        }
        if (!lora_down.valid()) {
            using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, typename GEMM::EpilogueNop>;
            return launch_bias.template operator()<Epilogue>({
                typename LoraUp::EpilogueLoraUp::Arguments{
                    .lora_act = lora_act_in.data_ptr<float>(),
                    .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
                    .scales = scales,
                },
                midArgs,
                nextArgs,
                {}
            });
    scale_t scales;
    if constexpr (scales.size() > 0) {
        for (size_t i = 0; i < scales.size(); i++) {
            scales[i] = i < lora_scales.size() ? lora_scales[i] : 0.0f;
        }
    }

        const int rank_down = lora_down.shape[1];
        assert(rank_down == rank_up);
        assert(lora_down.shape[0] == N);
        // assert(lora_down.shape[1] == Lora::LORA_RANK);
        assert(lora_act_out.shape[0] == M);
        assert(lora_act_out.shape[1] == rank_down);
        lora_act_out.zero_();

        // dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
        using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
        using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, typename GEMM::EpilogueNop>;
    if (rank_down == 0) {
        using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, typename GEMM::EpilogueNop>;
        return launch_bias.template operator()<Epilogue>({
            typename LoraUp::EpilogueLoraUp::Arguments{
                .lora_act = lora_act_in.data_ptr<float>(),
                .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
                .rank = rank_up,
                .scales = scales,
                .alwaysfalse = false,
            },
            midArgs,
            typename LoraDown::EpilogueLoraDown::Arguments{
                .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                .lora_act = lora_act_out.data_ptr<float>(),
            },
            nextArgs,
            {}
        });
    }
    // });

    // assert(rank_down == rank_up);
    assert(rank_down % 16 == 0);
    assert(lora_down.shape[0] == N);
    // assert(lora_down.shape[1] == Lora::LORA_RANK);
    assert(lora_act_out.shape[0] == M);
    assert(lora_act_out.shape[1] == rank_down);
    lora_act_out.zero_();

    // dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {
    using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
    using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, typename GEMM::EpilogueNop>;
    return launch_bias.template operator()<Epilogue>({
        typename LoraUp::EpilogueLoraUp::Arguments{
            .lora_act = lora_act_in.data_ptr<float>(),
            .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
            .rank = rank_up,
            .scales = scales,
            .alwaysfalse = false,
        },
        midArgs,
        typename LoraDown::EpilogueLoraDown::Arguments{
            .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
            .lora_act = lora_act_out.data_ptr<float>(),
            .rank = rank_down,
            .alwaysfalse = false,
        },
        nextArgs,
        {}
    });
    // });
};

if (qout.valid() && oscales.valid()) {
...
...
@@ -280,7 +286,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
    // TODO: check if gelu is needed
    if (out.valid()) {
        launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
        launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename Epilogues::EpilogueGelu>({
            typename GEMM::EpilogueDefault::Arguments{
                .out = out.data_ptr<half_t>(),
                .actualM = actualM,
...
...
@@ -289,7 +295,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
            argsQuantize}, {});
    } else {
        launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
        launch_lora.template operator()<EpilogueQuantize, typename Epilogues::EpilogueGelu>(argsQuantize, {});
    }
...
...
@@ -297,7 +303,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
    assert(out_vk.valid());

    using Epilogue = typename GEMM::EpilogueLiteLA;
    using Epilogue = typename Epilogues::EpilogueLiteLA;

    assert(out_vk.dtype() == Tensor::FP32);
    assert(out_vk.ndims() == 4);
...
...
@@ -334,7 +340,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
    assert(rotary_emb.scalar_type() == Tensor::FP32);
    assert(rotary_emb.ndims() == 3);
    assert(rotary_emb.shape[0] * rotary_emb.shape[1] == M);
    assert(rotary_emb.shape[2] == GEMM::EpilogueRMSNormRope::HEAD_DIM);
    assert(rotary_emb.shape[2] == Epilogues::EpilogueRMSNormRope::HEAD_DIM);
    // assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
    // launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
...
...
@@ -348,8 +354,8 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
    //     .epsilon = 1e-6,
    // }, {});
    using EpilogueRope = typename GEMM::EpilogueRMSNormRope;
    auto argsRope = typename GEMM::EpilogueRMSNormRope::Arguments{
    using EpilogueRope = typename Epilogues::EpilogueRMSNormRope;
    auto argsRope = typename Epilogues::EpilogueRMSNormRope::Arguments{
        .rotary_emb = rotary_emb.data_ptr<typename EpilogueRope::packed_rotemb_t>(),
        .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
        .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
...
...
@@ -357,16 +363,16 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
    };
    if (out_q.valid()) {
        launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpiloguePackQKV>, typename GEMM::EpilogueNop>({
        launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename Epilogues::EpiloguePackQKV>, typename GEMM::EpilogueNop>({
            argsRope,
            typename GEMM::EpiloguePackQKV::Arguments{
                .out_q = out_q.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
                .out_k = out_k.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
                .out_v = out_v.data_ptr<typename GEMM::EpiloguePackQKV::packed_qkv_t>(),
            typename Epilogues::EpiloguePackQKV::Arguments{
                .out_q = out_q.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
                .out_k = out_k.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
                .out_v = out_v.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
                .actualM = attn_tokens,
                .strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
                .strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
                .strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(typename GEMM::EpiloguePackQKV::packed_qkv_t)),
                .strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
                .strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
                .strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
            }
        }, {});
    } else {
...
...
@@ -401,7 +407,7 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
    using Epilogue = typename GEMM::EpilogueLiteLA;
    using Epilogue = typename Epilogues::EpilogueLiteLA;

    int batch_size = vk.shape[0];
    int num_heads = vk.shape[1];
...
...
@@ -449,6 +455,8 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
    const int rank = lora_down.shape[1];
    assert(rank % 16 == 0);

    assert(lora_down.shape[0] == N);
    // assert(lora_down.shape[1] == Lora::LORA_RANK);
    assert(lora_act_out.shape[0] == M);
...
...
@@ -458,34 +466,36 @@ void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input
    dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

    dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
        dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
            using Lora = typename GEMM::Lora<RANK>;
            using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
            auto func = invoke_kernel<kernel, typename kernel::Arguments>;
            checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
            // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
            func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(typename kernel::Arguments{
                .input = input.data_ptr<half_t>(),
                .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
                .output = output.data_ptr<packed_act_t>(),
                .oscales = oscales.data_ptr<typename kernel::oscales_t>(),
                .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                .lora_act = lora_act_out.data_ptr<float>(),
                .M = M,
                .N = N,
                .actualM = actualM,
                .actualN = actualN,
            });
            checkCUDA(cudaGetLastError());
        });
    // dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
    dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
        // using Lora = typename GEMM::Lora<RANK>;
        using kernel = typename GEMM::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
        auto func = invoke_kernel<kernel, typename kernel::Arguments>;
        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
        // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
        func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(typename kernel::Arguments{
            .input = input.data_ptr<half_t>(),
            .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
            .output = output.data_ptr<packed_act_t>(),
            .oscales = oscales.data_ptr<typename kernel::oscales_t>(),
            .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
            .lora_act = lora_act_out.data_ptr<float>(),
            .lora_rank = rank,
            .M = M,
            .N = N,
            .actualM = actualM,
            .actualN = actualN,
            .alwaysfalse = false,
        });
        checkCUDA(cudaGetLastError());
    });
    // });
}

template<typename Config, bool USE_FP4>
...
...
src/kernels/zgemm/gemm_w4a4_test.cu

#include "zgemm.h"
#include "gemm_w4a4.cuh"
#include "epilogues.cuh"

namespace nunchaku::kernels {
...
...
@@ -10,7 +11,7 @@ void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k
    assert(input.shape.dataExtent == output.shape.dataExtent);
    assert(input.scalar_type() == Tensor::FP16);

    using GEMM = GEMM_W4A4<GEMMConfig_W4A4_FP16>;
    using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
    using Epilogue = GEMM::EpilogueRMSNormRope;

    assert(M % GEMM::BLOCK_M == 0);
...
...
@@ -51,7 +52,7 @@ void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int n
    Tensor output = Tensor::empty_like(input);

    using GEMM = GEMM_W4A4<GEMMConfig_W4A4_FP16>;
    using GEMM = Epilogues<GEMMConfig_W4A4_FP16>;
    using Epilogue = GEMM::EpiloguePackQKV;

    assert(M % GEMM::BLOCK_M == 0);
...
...
src/kernels/zgemm/lora.cuh
0 → 100644

#pragma once

#include "gemm_base.cuh"

namespace nunchaku::kernels {

template<typename Config>
class Lora;

#ifndef __INTELLISENSE__
template<typename Config>
class Lora : public GEMMBase<Config> {
#else
template<>
class Lora<GEMMConfig_W4A4_FP16> : public GEMMBase<GEMMConfig_W4A4_FP16> {
    using Config = GEMMConfig_W4A4_FP16;
#endif
public:
    IMPORT_GEMM_BASE(Config);

public:
    static constexpr int MAX_RANK = 1024;
    static constexpr int WARP_R = 16;

    // static constexpr int LORA_RANK = rank;
    static constexpr int LORA_M_TILES = WARP_M / 16;
    static constexpr int LORA_R_TILES = WARP_R / 16;
    static constexpr int LORA_N_TILES = WARP_N / 16;

    static_assert(LORA_M_TILES == WARP_M_TILES);
    static_assert(LORA_N_TILES == WARP_N_TILES);

    // lora_down: [WARP_M, WARP_N] x [WARP_N, R] (row-wise) = [WARP_M, R]
    // lora up:   [WARP_M, R] x [WARP_N, R] (col-wise) = [WARP_M, WARP_N]
    // we use fp32 for lora activation since there's no bf16 reduction in sm_89 :(
    using lora_act_warp   = std::array<packed_f32psum_t, LORA_M_TILES * LORA_R_TILES>;
    using lora_act16_warp = std::array<packed_fpsum_t, LORA_M_TILES * LORA_R_TILES>;
    using lora_wgt_warp   = std::array<packed_fpsum_t, LORA_N_TILES * LORA_R_TILES>;

    using scale_t = std::array<float, MAX_RANK / 16>;

    // lora_wgt: [N / 16, rank / WARP_R, LORA_R_TILES, WARP_SIZE] of packed_fpsum_t
    //           [N / 16, rank / 16, WARP_SIZE]
    __device__ __forceinline__
    static void load_lora_wgt(const packed_fpsum_t *ptr, int rtile, int rank, lora_wgt_warp &result, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;

        const packed_fpsum_t *ptr_lane = &ptr[rtile * LORA_R_TILES * WARP_SIZE + laneId];
        const int stride_ntile = rank / 16 * WARP_SIZE;

        unrolled_loop<LORA_N_TILES>([&]<int n>() {
            unrolled_loop<LORA_R_TILES>([&]<int r>() {
                constexpr int roffset = r * WARP_SIZE;
                const int noffset = n * stride_ntile;
                result[n * LORA_R_TILES + r] = load_pred(ptr_lane + noffset + roffset, pred);
            });
        });
    }

    // lora_act: [M / BLOCK_M, rank / WARP_R, NUM_WARPS, LORA_M_TILES, LORA_R_TILES, 8, WARP_SIZE] of float
    __device__ __forceinline__
    static void load_lora_act(const float *ptr, int rtile, lora_act_warp &result, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        const float *ptrlane = &ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];

        unrolled_loop<LORA_M_TILES>([&]<int m>() {
            unrolled_loop<LORA_R_TILES>([&]<int r> {
                constexpr int i = m * LORA_R_TILES + r;
                unrolled_loop<8>([&]<int j>() {
                    constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
                    result[i].data[j] = load_pred(ptrlane + offset, pred);  // * scales[rtile * LORA_R_TILES + r];
                });
                // CHECK_NAN(tmp, "load_lora_act.tmp");
            });
        });
    }

    // no vector reduction in sm_89 :(
    __device__ __forceinline__
    static void reduce_lora_act(float *ptr, int rtile, lora_act_warp val, bool pred) {
        const int laneId = threadIdx.x % WARP_SIZE;
        const int warpId = threadIdx.x / WARP_SIZE;

        float *ptrlane = &ptr[(rtile * NUM_WARPS + warpId) * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE + laneId];

        unrolled_loop<LORA_M_TILES * LORA_R_TILES>([&]<int i>() {
            unrolled_loop<8>([&]<int j>() {
                constexpr int offset = i * 8 * WARP_SIZE + j * WARP_SIZE;
                reduce_add_pred(&ptrlane[offset], val[i].data[j], pred);
            });
        });
    }

    // __device__ __forceinline__
    // static void reduce_lora_act(float *ptr, lora_act_warp val, int m) {
    //     const int laneId = threadIdx.x % WARP_SIZE;
    //     float *ptrlane = ptr + laneId + m * LORA_R_TILES * 8 * WARP_SIZE;
    //     unrolled_loop<LORA_R_TILES>([&]<int r>() {
    //         unrolled_loop<8>([&]<int j>() {
    //             constexpr int offset = r * 8 * WARP_SIZE + j * WARP_SIZE;
    //             reduce_add(&ptrlane[offset], val[m * LORA_R_TILES + r].data[j]);
    //         });
    //     });
    // }

    struct EpilogueLoraUp {
        struct Arguments {
            const float *lora_act;
            const packed_fpsum_t *lora_wgt_up;
            int rank;
            scale_t scales;
            bool alwaysfalse;
        };

        __device__ __forceinline__
        static void apply_lora_up(fpsum_warp &fpsum, const float *act, const packed_fpsum_t *wgt, const scale_t &scales, int rank, bool alwaysfalse) {
            constexpr int NUM_STAGES = 2;

            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;

            lora_act_warp lora_act[NUM_STAGES];  // 32
            lora_wgt_warp lora_wgt[NUM_STAGES];  // 64

            int dummy = 0;

#pragma unroll
            for (int k = 0; k < NUM_STAGES - 1; k++) {
                // we have rank > 0
                const bool pred = k == 0 ? true : k < rank / WARP_R;
                load_lora_act(act, 0, lora_act[k], pred);
                load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred);
            }

            f32psum_warp f32psum = packed_fp16_to_fp32(fpsum);  // 128

            auto compute = [&scales](lora_act_warp A, lora_wgt_warp W, f32psum_warp &f32psum, int rtile) ALWAYSINLINE {
                lora_act16_warp A_fp16;
                for (int m = 0; m < LORA_M_TILES; m++) {
                    for (int r = 0; r < LORA_R_TILES; r++) {
                        packed_f32psum_t pack = A[m * LORA_R_TILES + r];
#pragma unroll
                        for (int j = 0; j < 8; j++) {
                            pack.data[j] *= scales[rtile * LORA_R_TILES + r];
                        }
                        A_fp16[m * LORA_R_TILES + r] = packed_fp32_to_fp16(pack);
                    }
                }
                for (int m = 0; m < LORA_M_TILES; m++) {
                    for (int n = 0; n < LORA_N_TILES; n++) {
                        for (int r = 0; r < LORA_R_TILES; r++) {
                            CHECK_NAN(lora_act[m * LORA_R_TILES + r], "lora_act");
                            CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "lora_wgt");
                            f32psum[m * WARP_N_TILES + n] = mma_f16xf16_f32(A_fp16[m * LORA_R_TILES + r], W[n * LORA_R_TILES + r], f32psum[m * WARP_N_TILES + n]);
                        }
                    }
                }
            };

            for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
                for (int k2 = 0; k2 < NUM_STAGES; k2++) {
                    if (k1 + k2 >= rank / WARP_R) {
                        break;
                    }
                    int nextk = k1 + k2 + NUM_STAGES - 1;
                    int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
                    bool pred = nextk < rank / WARP_R;

                    if (alwaysfalse) {
                        act += kernels::bit_cast<int>(lora_act[k2][0].data[0]);
                    }
                    if (alwaysfalse) {
                        dummy = clock();
                    }

                    load_lora_act(act, nextk, lora_act[idx], pred);
                    load_lora_wgt(wgt, nextk, rank, lora_wgt[idx], pred);

                    compute(lora_act[k2], lora_wgt[k2], f32psum, k1 + k2);
                }
            }

            // NVCC does not know rank > 0 :(
            // it will generate a branch instruction to skip the initial load
            // the branch splits the basic blocks and prevents the overlap of memory access and computing (packed_fp16_to_fp32)
            // add fake dependency of loaded data so NVCC will not skip the load
#pragma unroll
            for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
                for (auto &&data : lora_act[k]) {
#pragma unroll
                    for (int i = 0; i < 8; i++) {
                        dummy ^= kernels::bit_cast<int>(data.data[i]);
                    }
                }
#pragma unroll
                for (auto &&data : lora_wgt[k]) {
#pragma unroll
                    for (int i = 0; i < 4; i++) {
                        dummy ^= kernels::bit_cast<int>(data.data[i]);
                    }
                }
            }
            unused_var(dummy, alwaysfalse);

            fpsum = packed_fp32_to_fp16(f32psum);
        }

        __device__ __forceinline__
        void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
            const int bm = binfo.bm;
            const int bn = binfo.bn;

            CHECK_NAN(fpsum, "fpsum");

            apply_lora_up(
                fpsum,
                args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
                args.lora_wgt_up + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
                args.scales,
                args.rank,
                args.alwaysfalse);

            CHECK_NAN(fpsum, "fpsum");
        }
    };

    struct EpilogueLoraDown {
        struct Arguments {
            const packed_fpsum_t *lora_wgt_down;
            float *lora_act;
            int rank;
            bool alwaysfalse;
        };

        __device__ __forceinline__
        static void apply_lora_down(fpsum_warp &fpsum, float *act, const packed_fpsum_t *wgt, int rank, bool alwaysfalse) {
            constexpr int NUM_STAGES = 2;

            const int laneId = threadIdx.x % WARP_SIZE;
            const int warpId = threadIdx.x / WARP_SIZE;

            lora_wgt_warp lora_wgt[NUM_STAGES];  // 64

#pragma unroll
            for (int k = 0; k < NUM_STAGES - 1; k++) {
                // we have rank > 0
                bool pred = k == 0 ? true : k < rank / WARP_R;
                load_lora_wgt(wgt, 0, rank, lora_wgt[k], pred);
            }

            auto compute = [](lora_wgt_warp W, fpsum_warp fpsum) -> lora_act_warp {
                lora_act_warp lora_act;
                lora_act.fill(packed_f32psum_t::zeros());
#pragma unroll
                for (int m = 0; m < LORA_M_TILES; m++) {
#pragma unroll
                    for (int n = 0; n < LORA_N_TILES; n++) {
#pragma unroll
                        for (int r = 0; r < LORA_R_TILES; r++) {
                            auto &psum = lora_act[m * LORA_R_TILES + r];
                            CHECK_NAN(fpsum[m * WARP_N_TILES + n], "apply_lora_down.fpsum");
                            CHECK_NAN(lora_wgt[n * LORA_R_TILES + r], "apply_lora_down.lora_wgt");
                            psum = mma_f16xf16_f32(fpsum[m * WARP_N_TILES + n], W[n * LORA_R_TILES + r], psum);
                            CHECK_NAN(psum, "apply_lora_down.psum");
                        }
                    }
                }
                return lora_act;
            };

            int dummy = 0;

            for (int k1 = 0; k1 < rank / WARP_R; k1 += NUM_STAGES) {
#pragma unroll
                for (int k2 = 0; k2 < NUM_STAGES; k2++) {
                    if (k1 + k2 >= rank / WARP_R) {
                        break;
                    }
                    int nextk = k1 + k2 + NUM_STAGES - 1;
                    int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
                    bool pred = nextk < rank / WARP_R;

                    if (alwaysfalse) {
                        wgt += kernels::bit_cast<int>(lora_wgt[k2][0].data[0]);
                    }
                    if (alwaysfalse) {
                        dummy = clock();
                    }

                    load_lora_wgt(wgt, nextk, rank, lora_wgt[idx], pred);

                    if (alwaysfalse) {
                        dummy = clock();
                    }

                    lora_act_warp lora_act = compute(lora_wgt[k2], fpsum);
                    reduce_lora_act(act, k1 + k2, lora_act, true);
                }
            }

#pragma unroll
            for (int k = 0; k < NUM_STAGES - 1; k++) {
#pragma unroll
                for (auto &&data : lora_wgt[k]) {
#pragma unroll
                    for (int i = 0; i < 4; i++) {
                        dummy ^= kernels::bit_cast<int>(data.data[i]);
                    }
                }
            }
            unused_var(dummy, alwaysfalse);
        }

        __device__ __forceinline__
        void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, const Arguments &args) {
            const int bm = binfo.bm;
            const int bn = binfo.bn;

            apply_lora_down(
                fpsum,
                args.lora_act + bm * (args.rank / WARP_R) * (NUM_WARPS * LORA_M_TILES * LORA_R_TILES * 8 * WARP_SIZE),
                args.lora_wgt_down + bn * (BLOCK_N / 16) * (args.rank / 16) * WARP_SIZE,
                args.rank,
                args.alwaysfalse);
        }
    };
};

};  // namespace nunchaku::kernels
\ No newline at end of file
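A minimal host-side reference (not part of this commit) of the math the two epilogues implement, assuming row-major layouts X [M, N] and W_down/W_up [N, R]. In the real pipeline the up-projection consumes a lora activation produced by an earlier fused stage (lora_act_in) while the down-projection writes lora_act_out for the next GEMM; the sketch simply shows each projection in isolation, with one scale per 16-wide rank tile as in scale_t/WARP_R.

#include <vector>

// down projection: A = X * W_down   (what EpilogueLoraDown accumulates into lora_act)
void lora_down_reference(const std::vector<float> &X, const std::vector<float> &W_down,
                         std::vector<float> &A, int M, int N, int R) {
    for (int m = 0; m < M; m++)
        for (int r = 0; r < R; r++) {
            float acc = 0.f;
            for (int n = 0; n < N; n++) acc += X[m * N + n] * W_down[n * R + r];
            A[m * R + r] = acc;
        }
}

// up projection: Y = X + (scaled A) * W_up^T   (what EpilogueLoraUp adds onto the GEMM output)
void lora_up_reference(const std::vector<float> &X, const std::vector<float> &A,
                       const std::vector<float> &W_up, const std::vector<float> &scales,
                       std::vector<float> &Y, int M, int N, int R) {
    for (int m = 0; m < M; m++)
        for (int n = 0; n < N; n++) {
            float acc = 0.f;
            for (int r = 0; r < R; r++) acc += scales[r / 16] * A[m * R + r] * W_up[n * R + r];
            Y[m * N + n] = X[m * N + n] + acc;
        }
}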