fengzch-das / nunchaku / Commits

Commit b1b44398, authored Feb 26, 2025 by Samuel Tesfai

Fixing merges

Parents: 004e4e31, 4b9c2e03

Showing 15 changed files with 849 additions and 192 deletions (+849 -192)
Changed files:
  src/SanaModel.cpp                             +20   -15
  src/SanaModel.h                                +5    -4
  src/Serialization.cpp                          +2    -0
  src/Tensor.h                                   +4    -1
  src/common.h                                  +10    -0
  src/interop/torch.cpp                          +4    -2
  src/kernels/awq/gemv_awq.cu                    +4    -2
  src/kernels/zgemm/gemm_base.cuh              +102   -38
  src/kernels/zgemm/gemm_utils.cuh              +93    -0
  src/kernels/zgemm/gemm_w4a4.cu                +22    -5
  src/kernels/zgemm/gemm_w4a4.cuh              +412   -28
  src/kernels/zgemm/gemm_w4a4_launch.cuh         +7    -2
  src/kernels/zgemm/gemm_w4a4_launch_impl.cuh  +157   -91
  src/kernels/zgemm/gemm_w8a8.cu                 +2    -2
  src/kernels/zgemm/zgemm.h                      +5    -2
src/SanaModel.cpp

@@ -8,11 +8,11 @@
 using spdlog::fmt_lib::format;
 using namespace nunchaku;
 
-SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device) :
+SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     dim(dim), dim_pad(ceilDiv(dim, 128) * 128),
-    qkv_proj(dim, dim_pad * 3, bias, dtype, device),
-    out_proj(dim_pad, dim, bias, dtype, device),
+    qkv_proj(dim, dim_pad * 3, bias, use_fp4, dtype, device),
+    out_proj(dim_pad, dim, bias, use_fp4, dtype, device),
     pag_to_v(std::nullopt)
 {
     registerChildren
@@ -21,7 +21,7 @@ SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, Tensor::S
     ;
     if (pag) {
-        pag_to_v.emplace(dim, dim_pad, bias, dtype, device);
+        pag_to_v.emplace(dim, dim_pad, bias, use_fp4, dtype, device);
         registerChildren(pag_to_v.value(), "pag_to_v");
     }
 }
@@ -63,7 +63,11 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
         qkv_proj.wscales, {}, {},
         qact.lora_act, qkv_proj.lora_up,
         {}, {}, {}, {}, {},
         qkv_proj.bias, {},
         vk, q,
-        qact.is_unsigned, qkv_proj.lora_scales, false);
+        qact.is_unsigned, qkv_proj.lora_scales, false,
+        qkv_proj.use_fp4, *qkv_proj.wtscale.data_ptr<float>(),
+        qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{});
 
     debug("vk", vk);
     debug("q", q);
@@ -121,11 +125,11 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
     return out;
 }
 
-MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device) :
+MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     num_heads(num_heads), head_dim(head_dim),
-    q_linear(num_heads * head_dim, num_heads * head_dim, true, dtype, device),
+    q_linear(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device),
     kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
-    out_proj(num_heads * head_dim, num_heads * head_dim, true, dtype, device)
+    out_proj(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device)
 {
     registerChildren(q_linear, "q_linear")
@@ -173,11 +177,11 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
     return out_proj.forward(attn_output);
 }
 
-SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device) :
+SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     in_features(in_features), hidden_features(hidden_features),
-    inverted_conv(in_features, hidden_features * 2, true, dtype, device),
+    inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
     depth_conv(hidden_features * 2, true, dtype, device),
-    point_conv(hidden_features, in_features, false, dtype, device)
+    point_conv(hidden_features, in_features, false, use_fp4, dtype, device)
 {
     registerChildren(inverted_conv, "inverted_conv")
@@ -200,11 +204,11 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
     return point_conv.forward_quant(qact);
 }
 
-SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device) :
+SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
-    attn(hidden_size, false, pag, dtype, device),
-    cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, dtype, device),
-    ff(hidden_size, intermediate_size, dtype, device),
+    attn(hidden_size, false, pag, use_fp4, dtype, device),
+    cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
+    ff(hidden_size, intermediate_size, use_fp4, dtype, device),
     norm1(hidden_size, 1e-6, false, dtype, device),
     norm2(hidden_size, 1e-6, false, dtype, device)
 {
@@ -313,6 +317,7 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
         ceilDiv(int(round(config.expand_ratio * inner_dim)), 64) * 64,
         config.num_cross_attention_heads,
         std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
+        config.use_fp4,
         dtype, device));
     registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
src/SanaModel.h

@@ -7,7 +7,7 @@
 class SanaLinearAttention : public Module {
 public:
-    SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device);
+    SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device);
 
     Tensor forward(Tensor x, Tensor out = {});
     Tensor forward_pag(Tensor x, bool cfg);
@@ -25,7 +25,7 @@ private:
 class MultiHeadCrossAttention : public Module {
 public:
-    MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device);
+    MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device);
 
     Tensor forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt);
@@ -41,7 +41,7 @@ private:
 class SanaGLUMBConv : public Module {
 public:
-    SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device);
+    SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device);
 
     Tensor forward(Tensor x, int H, int W);
@@ -57,7 +57,7 @@ private:
 class SanaLinearTransformerBlock : public Module {
 public:
-    SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device);
+    SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device);
 
     Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);
@@ -83,6 +83,7 @@ struct SanaConfig {
     int num_cross_attention_heads;
     double expand_ratio;
     std::vector<int> pag_layers;
+    bool use_fp4;
 };
 
 class SanaModel : public Module {
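The pattern throughout SanaModel.h/.cpp is that a single use_fp4 flag added to SanaConfig is threaded through every constructor down to the quantized linear layers. A minimal, self-contained sketch of that threading pattern (the class and member names here are illustrative stand-ins, not the real nunchaku types):

#include <iostream>
#include <vector>

// Hypothetical stand-ins for the real modules; only the flag plumbing is shown.
struct QuantizedLinear {
    bool use_fp4;
    QuantizedLinear(int in, int out, bool bias, bool use_fp4) : use_fp4(use_fp4) {}
};

struct Block {
    QuantizedLinear qkv, out;
    Block(int dim, bool use_fp4)
        : qkv(dim, dim * 3, /*bias=*/false, use_fp4),
          out(dim, dim, /*bias=*/false, use_fp4) {}
};

struct Config {
    int dim = 1024;
    int num_layers = 2;
    bool use_fp4 = false;   // mirrors SanaConfig::use_fp4
};

struct Model {
    std::vector<Block> blocks;
    explicit Model(const Config &cfg) {
        for (int i = 0; i < cfg.num_layers; i++)
            blocks.emplace_back(cfg.dim, cfg.use_fp4);  // flag flows top-down
    }
};

int main() {
    Model m(Config{.dim = 512, .num_layers = 3, .use_fp4 = true});
    std::cout << m.blocks[0].qkv.use_fp4 << "\n";  // prints 1
}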
src/Serialization.cpp

@@ -117,6 +117,8 @@ void SafeTensors::parseHeader() {
         {"I8", Tensor::INT8},
         {"I32", Tensor::INT32},
         {"I64", Tensor::INT64},
+        {"F8_E4M3", Tensor::FP8_E4M3},
+        {"F8_E5M2", Tensor::FP8_E5M2},
     };
     auto check = [](bool cond, std::source_location location = std::source_location::current()) {
src/Tensor.h

@@ -218,7 +218,8 @@ public:
     enum ScalarType {
         INVALID_SCALAR_TYPE,
         INT8, INT16, INT32, INT64,
-        FP16, FP32, BF16
+        FP16, FP32, BF16,
+        FP8_E4M3, FP8_E5M2,
     };
 
     struct TensorOptions {
@@ -546,6 +547,8 @@ inline const std::map<Tensor::ScalarType, size_t> Tensor::scalarSize = {
     {FP16, 2},
     {FP32, 4},
     {BF16, 2},
+    {FP8_E4M3, 1},
+    {FP8_E5M2, 1},
 };
 
 struct TensorsProvider {
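The new FP8 entries make element-size lookups work for the 1-byte float formats. A small sketch (hypothetical names, not the real Tensor class) of how such a dtype-to-size map is typically used to compute buffer sizes:

#include <cstddef>
#include <map>
#include <cassert>

enum ScalarType { INT8, FP16, FP32, BF16, FP8_E4M3, FP8_E5M2 };

const std::map<ScalarType, size_t> scalarSize = {
    {INT8, 1}, {FP16, 2}, {FP32, 4}, {BF16, 2},
    {FP8_E4M3, 1}, {FP8_E5M2, 1},   // the new 1-byte FP8 formats
};

// Byte size of a tensor = number of elements * element size.
size_t bufferBytes(size_t numel, ScalarType t) { return numel * scalarSize.at(t); }

int main() {
    assert(bufferBytes(4096, FP8_E4M3) == 4096);  // FP8 halves FP16 storage
    assert(bufferBytes(4096, FP16) == 8192);
}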
src/common.h

@@ -9,6 +9,7 @@
 #include <memory>
 #include <source_location>
 #include <vector>
+#include <list>
 #include <stack>
 #include <map>
 #include <unordered_map>
@@ -79,6 +80,15 @@ constexpr T ceilDiv(T a, T b) {
     return (a + b - 1) / b;
 }
 
+template<typename T>
+constexpr int log2Up(T value) {
+    if (value <= 0) return 0;
+    if (value == 1) return 0;
+    return log2Up((value + 1) / 2) + 1;
+}
+
 struct CUBLASWrapper {
     cublasHandle_t handle = nullptr;
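Both helpers are constexpr, so their results can be checked at compile time. A quick sketch of what the new log2Up recursion computes (the ceiling of log2) alongside ceilDiv; the static_asserts are illustrative, not from the repository:

#include <cstdio>

template<typename T>
constexpr T ceilDiv(T a, T b) { return (a + b - 1) / b; }

template<typename T>
constexpr int log2Up(T value) {
    if (value <= 0) return 0;
    if (value == 1) return 0;
    return log2Up((value + 1) / 2) + 1;   // halve (rounding up) until reaching 1
}

// Compile-time checks: log2Up(value) == ceil(log2(value)).
static_assert(log2Up(1) == 0);
static_assert(log2Up(2) == 1);
static_assert(log2Up(3) == 2);
static_assert(log2Up(8) == 3);
static_assert(log2Up(9) == 4);
static_assert(ceilDiv(100, 64) == 2);

int main() { std::printf("log2Up(24) = %d\n", log2Up(24)); }  // prints 5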
src/interop/torch.cpp

@@ -28,8 +28,9 @@ Tensor from_torch(at::Tensor input) {
         {at::ScalarType::Float, Tensor::FP32},
         {at::ScalarType::Half, Tensor::FP16},
         {at::ScalarType::BFloat16, Tensor::BF16},
         {at::ScalarType::Short, Tensor::INT16},
+        {at::ScalarType::Float8_e4m3fn, Tensor::FP8_E4M3},
+        {at::ScalarType::Float8_e5m2, Tensor::FP8_E5M2},
     };
     result.scalarType = mapType.at(input.scalar_type());
@@ -55,8 +56,9 @@ at::Tensor to_torch(Tensor input) {
         {Tensor::FP32, at::ScalarType::Float},
         {Tensor::FP16, at::ScalarType::Half},
         {Tensor::BF16, at::ScalarType::BFloat16},
         {Tensor::INT16, at::ScalarType::Short},
+        {Tensor::FP8_E4M3, at::ScalarType::Float8_e4m3fn},
+        {Tensor::FP8_E5M2, at::ScalarType::Float8_e5m2},
     };
     c10::TensorOptions opts(mapType.at(input.scalar_type()));
src/kernels/awq/gemv_awq.cu

@@ -140,8 +140,10 @@ __global__ void gemv_kernel(
     for (int i = 0; i < Num; ++i)
         psum[i] = static_cast<accum_t>(0.f);
-    extern __shared__ uint8_t shmem[];
-    float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
+    // extern __shared__ uint8_t shmem[];
+    // float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
+    __shared__ float out_smem[BlockSize / WARP_SIZE * 2][Num * kInterleave];
 
     const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
     const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
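This change replaces dynamically sized extern __shared__ memory with a statically sized array, so the size is fixed at compile time rather than supplied at launch. A minimal CUDA sketch of the same pattern; the kernel and names below are illustrative only, not the AWQ GEMV kernel:

#include <cuda_runtime.h>
#include <cstdio>

// Statically sized shared memory: the array size is a compile-time constant,
// so no dynamic shared-memory byte count is needed in the launch configuration.
template<int BlockSize>
__global__ void block_sum(const float *in, float *out, int n) {
    __shared__ float smem[BlockSize];
    int tid = threadIdx.x;
    int idx = blockIdx.x * BlockSize + tid;
    smem[tid] = (idx < n) ? in[idx] : 0.0f;
    __syncthreads();
    for (int stride = BlockSize / 2; stride > 0; stride >>= 1) {
        if (tid < stride) smem[tid] += smem[tid + stride];
        __syncthreads();
    }
    if (tid == 0) out[blockIdx.x] = smem[0];
}

int main() {
    const int n = 1024;
    float *in, *out;
    cudaMallocManaged(&in, n * sizeof(float));
    cudaMallocManaged(&out, (n / 256) * sizeof(float));
    for (int i = 0; i < n; i++) in[i] = 1.0f;
    block_sum<256><<<n / 256, 256>>>(in, out, n);   // no third shmem-bytes launch argument needed
    cudaDeviceSynchronize();
    std::printf("%f\n", out[0]);                    // 256.0
    cudaFree(in); cudaFree(out);
}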
src/kernels/zgemm/gemm_base.cuh

@@ -319,10 +319,10 @@ public:
         int warpId = threadIdx.x / WARP_SIZE;
 #pragma unroll
         for (int i = 0; i < WARP_M_TILES; i++) {
-            if (pred) {
+            // if (pred) {
                 // out[i] = load(&act[((warpId * WARP_M_TILES + i) * K / WARP_K + k) * WARP_SIZE + laneId]);
-                out[i] = load(&act[((k * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId]);
-            }
+                out[i] = load_pred(&act[((k * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId], pred);
+            // }
         }
     }
@@ -336,12 +336,12 @@ public:
         // int offset = K / WARP_K * WARP_SIZE;
 #pragma unroll
         for (int i = 0; i < WARP_N_TILES; i++) {
-            if (pred) {
+            // if (pred) {
                 // out[i] = load(&wgt[(i * K / WARP_K + k) * WARP_SIZE + laneId]);
                 // out[i] = load(&wgt[(i + k * WARP_N_TILES) * WARP_SIZE + laneId]);
-                out[i] = load(&ptr[i * WARP_SIZE]);
+                out[i] = load_pred(&ptr[i * WARP_SIZE], pred);
                 // ptr += offset;
-            }
+            // }
         }
     }
@@ -352,11 +352,11 @@ public:
         int warpId = threadIdx.x / WARP_SIZE;
 #pragma unroll
         for (int i = 0; i < ASCALES_NUM_PACKS; i++) {
-            if (pred && laneId < ASCALES_VALID_LANES) {
+            // if (pred && laneId < ASCALES_VALID_LANES) {
                 // out[i] = ascales[(group * M / WARP_M + warpId) * ASCALES_VALID_LANES * ASCALES_NUM_PACKS + i * ASCALES_VALID_LANES + laneId];
-                out[i] = ascales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES + i * ASCALES_VALID_LANES + laneId];
+                out[i] = load_pred(&ascales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES + i * ASCALES_VALID_LANES + laneId], pred && laneId < ASCALES_VALID_LANES);
-            }
+            // }
         }
     }
@@ -373,13 +373,13 @@ public:
 #pragma unroll
         for (int i = 0; i < WSCALES_NUM_PACKS; i++) {
-            if (pred && laneId < WSCALES_VALID_LANES) {
+            // if (pred && laneId < WSCALES_VALID_LANES) {
                 // out[i] = wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId];
                 // out[i] = load(&wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId]);
-                out[i] = load(&wscales[(group * WSCALES_NUM_PACKS + i) * WSCALES_VALID_LANES + laneId]);
+                out[i] = load_pred(&wscales[(group * WSCALES_NUM_PACKS + i) * WSCALES_VALID_LANES + laneId], pred && laneId < WSCALES_VALID_LANES);
                 // out[i] = load(&ptr[i * WSCALES_VALID_LANES]);
-            }
+            // }
         }
     }
@@ -400,7 +400,7 @@ public:
         return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
     }
 
-    template<typename F>
+    template<bool FAST_I2F = false, typename F>
     __device__ __forceinline__
     static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, fpsum_warp &fpsum) {
         const int laneId = threadIdx.x % WARP_SIZE;
@@ -429,12 +429,31 @@ public:
                 //     printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
                 // }
-                fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
-                fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
-                fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
-                fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
+                auto scale_fma_normal = [&]() ALWAYSINLINE {
+                    fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
+                    fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
+                    fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
+                    fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
+                };
+                // should be faster on sm_80
+                auto scale_fma_fast = [&]() ALWAYSINLINE {
+                    fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[0]), int2float_fast(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
+                    fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[2]), int2float_fast(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
+                    fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[4]), int2float_fast(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
+                    fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[6]), int2float_fast(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
+                };
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 800
+                if constexpr (FAST_I2F) {
+                    scale_fma_fast();
+                } else {
+                    scale_fma_normal();
+                }
+#else
+                scale_fma_normal();
+#endif
                 // if (threadIdx.x == 3 && j == 1 && i == 0) {
                 //     printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
                 // }
@@ -575,9 +594,9 @@ public:
             (plugins(i * INSN_M + row, pack), ...);
             bool pred = i * INSN_M + row < maxRows && laneId * PACK_SIZE < maxCols;
-            if (pred) {
-                store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack);
-            }
+            // if (pred) {
+                store_pred(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack, pred);
+            // }
         }
         __syncwarp();
@@ -602,9 +621,9 @@ public:
             (plugins(i * INSN_M + 8 + row, pack), ...);
             bool pred = i * INSN_M + 8 + row < maxRows && laneId * PACK_SIZE < maxCols;
-            if (pred) {
-                store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]), pack);
-            }
+            // if (pred) {
+                store_pred(reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]), pack, pred);
+            // }
         }
         __syncwarp();
     }
@@ -680,33 +699,61 @@ public:
     }
 };
 
+template<bool USE_BIAS = true, bool USE_SCALE = false>
 struct EpilogueBias {
     struct Arguments {
         const packed_wscale_t *bias;
+        // [N / BLOCK_N, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t
+        const packed_wscale_t *scale;
     };
 
     __device__ __forceinline__
-    void apply_bias(fpsum_warp &fpsum, int M, int N, int K, const packed_wscale_t *bias) {
+    void apply_bias(fpsum_warp &fpsum, int M, int N, int K, const packed_wscale_t *bias, const packed_wscale_t *scale) {
         const int laneId = threadIdx.x % WARP_SIZE;
 
         // if (laneId == 0) {
         //     printf("block.x=%d block.y=%d warpId=%d bias=%p\n", blockIdx.x, blockIdx.y, threadIdx.x / WARP_SIZE, bias);
         // }
 
-        wscale_warp b;
-        load_wscale(bias, 0, N, b, true);
+        wscale_warp b, s;
+        if constexpr (USE_BIAS) {
+            load_wscale(bias, 0, N, b, true);
+        }
+        if constexpr (USE_SCALE) {
+            load_wscale(scale, 0, N, s, true);
+        }
 
         for (int j = 0; j < WARP_N_TILES; j++) {
-            half2_t b1 = broadcast_wscale(b, j * 4, laneId);
-            half2_t b2 = broadcast_wscale(b, j * 4 + 2, laneId);
+            half2_t b1, b2;
+            half2_t s1, s2;
+            if constexpr (USE_BIAS) {
+                b1 = broadcast_wscale(b, j * 4, laneId);
+                b2 = broadcast_wscale(b, j * 4 + 2, laneId);
+            }
+            if constexpr (USE_SCALE) {
+                s1 = broadcast_wscale(s, j * 4, laneId);
+                s2 = broadcast_wscale(s, j * 4 + 2, laneId);
+            }
 
             for (int i = 0; i < WARP_M_TILES; i++) {
                 auto &fsum = fpsum[i * WARP_N_TILES + j];
 
-                fsum.data[0] = __hadd2(fsum.data[0], b1);
-                fsum.data[1] = __hadd2(fsum.data[1], b1);
-                fsum.data[2] = __hadd2(fsum.data[2], b2);
-                fsum.data[3] = __hadd2(fsum.data[3], b2);
+                if constexpr (USE_SCALE && USE_BIAS) {
+                    fsum.data[0] = __hfma2(fsum.data[0], s1, b1);
+                    fsum.data[1] = __hfma2(fsum.data[1], s1, b1);
+                    fsum.data[2] = __hfma2(fsum.data[2], s2, b2);
+                    fsum.data[3] = __hfma2(fsum.data[3], s2, b2);
+                } else if constexpr (USE_SCALE) {
+                    fsum.data[0] = __hmul2(fsum.data[0], s1);
+                    fsum.data[1] = __hmul2(fsum.data[1], s1);
+                    fsum.data[2] = __hmul2(fsum.data[2], s2);
+                    fsum.data[3] = __hmul2(fsum.data[3], s2);
+                } else if constexpr (USE_BIAS) {
+                    fsum.data[0] = __hadd2(fsum.data[0], b1);
+                    fsum.data[1] = __hadd2(fsum.data[1], b1);
+                    fsum.data[2] = __hadd2(fsum.data[2], b2);
+                    fsum.data[3] = __hadd2(fsum.data[3], b2);
+                }
             }
         }
     }
@@ -714,10 +761,13 @@ public:
     __device__ __forceinline__
     void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
         const int bn = binfo.bn;
-        apply_bias(fpsum, M, N, K, args.bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES);
+        if constexpr (USE_BIAS || USE_SCALE) {
+            apply_bias(fpsum, M, N, K,
+                       args.bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
+                       args.scale + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES);
+        }
     }
 };
@@ -797,7 +847,21 @@ public:
     using typename Base::unpack_fpsum; \
     using typename Base::EpilogueDefault; \
     using typename Base::EpilogueNop; \
-    using typename Base::EpilogueBias;
+    template<bool USE_BIAS, bool USE_SCALE> \
+    using EpilogueBias = typename Base::EpilogueBias<USE_BIAS, USE_SCALE>; \
+    using Base::mma_f16xf16_f32; \
+    using Base::packed_fp32_to_fp16; \
+    using Base::packed_fp16_to_fp32; \
+    using Base::load_act; \
+    using Base::load_wgt; \
+    using Base::load_ascale; \
+    using Base::load_wscale; \
+    using Base::broadcast_wscale; \
+    using Base::broadcast_ascale; \
+    using Base::apply_scales; \
+    using Base::pack_ascales; \
+    using Base::pack_wscales; \
+    using Base::apply_act;
 
 template<typename kernel, typename... T>
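The reworked EpilogueBias selects one of three per-element updates at compile time: x*s + b when both scale and bias are present, x*s for scale only, and x + b for bias only. A scalar host-side sketch of that if constexpr selection, with the half2 intrinsics replaced by plain floats for clarity (names here are illustrative, not the kernel's types):

#include <cassert>

// Compile-time selection of the epilogue update, mirroring EpilogueBias<USE_BIAS, USE_SCALE>.
template<bool USE_BIAS, bool USE_SCALE>
float apply_epilogue(float x, float scale, float bias) {
    if constexpr (USE_SCALE && USE_BIAS) {
        return x * scale + bias;        // fused multiply-add (__hfma2 in the kernel)
    } else if constexpr (USE_SCALE) {
        return x * scale;               // __hmul2
    } else if constexpr (USE_BIAS) {
        return x + bias;                // __hadd2
    } else {
        return x;                       // no-op epilogue
    }
}

int main() {
    assert((apply_epilogue<true, true>(2.0f, 3.0f, 1.0f)) == 7.0f);
    assert((apply_epilogue<false, true>(2.0f, 3.0f, 1.0f)) == 6.0f);
    assert((apply_epilogue<true, false>(2.0f, 3.0f, 1.0f)) == 3.0f);
}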
src/kernels/zgemm/gemm_utils.cuh

@@ -43,6 +43,41 @@ static T load(const T *addr) {
     return *addr;
 }
 
+template<typename T>
+__device__ __forceinline__
+static T load_pred(const T *addr, bool pred) {
+    if constexpr (sizeof(T) == 4) {
+        uint32_t data;
+        asm volatile ("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
+                      "@loadpred ld.global.nc.b32 %0, [%1];"
+                      "}" : "=r"(data) : "l"(addr), "r"((int)pred));
+        return *reinterpret_cast<T *>(&data);
+    }
+    if constexpr (sizeof(T) == 8) {
+        uint2 data;
+        asm volatile ("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
+                      "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
+                      "}" : "=r"(data.x), "=r"(data.y) : "l"(addr), "r"((int)pred));
+        return *reinterpret_cast<T *>(&data);
+    }
+    if constexpr (sizeof(T) == 16) {
+        uint4 data;
+        asm volatile ("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
+                      "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
+                      "}" : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) : "l"(addr), "r"((int)pred));
+        return *reinterpret_cast<T *>(&data);
+    }
+    T result;
+    if (pred) {
+        result = *addr;
+    }
+    return result;
+}
+
 template<bool shmem = false, typename T>
 __device__ __forceinline__
 static void store(T *addr, T val) {
@@ -76,6 +111,39 @@ static void store(T *addr, T val) {
     *addr = val;
 }
 
+template<typename T>
+__device__ __forceinline__
+static void store_pred(T *addr, T val, bool pred) {
+    if constexpr (sizeof(T) == 4) {
+        uint32_t data = *reinterpret_cast<uint32_t *>(&val);
+        asm volatile ("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+                      "@storepred st.global.cg.b32 [%1], %2;"
+                      "}" :: "r"((int)pred), "l"(addr), "r"(data));
+        return;
+    }
+    if constexpr (sizeof(T) == 8) {
+        uint2 data = *reinterpret_cast<uint2 *>(&val);
+        asm volatile ("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+                      "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
+                      "}" :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y));
+        return;
+    }
+    if constexpr (sizeof(T) == 16) {
+        uint4 data = *reinterpret_cast<uint4 *>(&val);
+        asm volatile ("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+                      "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
+                      "}" :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
+        return;
+    }
+    if (pred) {
+        *addr = val;
+    }
+}
+
 __device__ __forceinline__
 static float2 half22float2(half2 val) {
     return __half22float2(val);
@@ -159,6 +227,21 @@ uint32_t quantize_float2<8, false>(float2 value) {
     return result;
 }
 
+__device__ __forceinline__
+uint32_t quantize_float2_fp4(float2 value) {
+    uint32_t result;
+    asm volatile ("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }" : "=r"(result) : "f"(value.y), "f"(value.x));
+    return result;
+}
+
+__device__ __forceinline__
+uint32_t quantize_float4_fp8(float4 value) {
+    uint16_t lo, hi;
+    asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
+    asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
+    return uint32_t(lo) | (uint32_t(hi) << 16);
+}
+
 __device__ __forceinline__
 static float cuda_tanhf(float x) {
     float result;
@@ -271,4 +354,14 @@ static void unrolled_loop(F &&lambda) {
     call(std::make_integer_sequence<int, cnt>());
 }
 
+// int2float is slow on sm_80 and before
+// val in [-4194304, 4194303]
+__device__ __forceinline__
+static float int2float_fast(int val) {
+    float fval;
+    // fval = (val & 0x7FFFFF) ^ 0x4B400000
+    asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=f"(fval) : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
+    return fval - 12582912.0f;
+}
+
 };  // namespace nunchaku::kernels
\ No newline at end of file
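The int2float_fast trick relies on a float with bit pattern 0x4B400000 (1.5 * 2^23 = 12582912.0): folding the low 23 bits of val into its mantissa and subtracting 12582912.0f recovers val exactly within the stated range. A host-side sketch (plain C++, no PTX; the reference function is mine, not from the repository) that checks the identity:

#include <cassert>
#include <cstdint>
#include <cstring>

// Host-side model of int2float_fast: (val & 0x7FFFFF) ^ 0x4B400000, reinterpreted
// as float, minus 12582912.0f (= 1.5 * 2^23). Valid for val in [-4194304, 4194303].
static float int2float_fast_ref(int val) {
    uint32_t bits = (static_cast<uint32_t>(val) & 0x7FFFFFu) ^ 0x4B400000u;
    float fval;
    std::memcpy(&fval, &bits, sizeof(fval));
    return fval - 12582912.0f;
}

int main() {
    for (int v : {-4194304, -12345, -1, 0, 1, 42, 4194303})
        assert(int2float_fast_ref(v) == static_cast<float>(v));
}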
src/kernels/zgemm/gemm_w4a4.cu

@@ -36,9 +36,23 @@ void gemm_w4a4(
     Tensor out_linearattn, // linear [B, (M), N / 3]
     bool act_unsigned,
     std::vector<float> lora_scales,  // [R / 16]
-    bool fuse_silu
+    bool fuse_silu,
+    bool fp4,
+    float alpha,
+    Tensor wcscales
 )
 {
-    invoke_launch(ascales.dtype(), [&]<typename Config>() {
+    Tensor::ScalarType dtype = Tensor::INVALID_SCALAR_TYPE;
+    if (!fp4) {
+        dtype = ascales.dtype();
+    } else {
+        for (auto tensor : {out, bias, lora_up, lora_down, poolout, wcscales}) {
+            if (tensor.valid()) {
+                assert(dtype == Tensor::INVALID_SCALAR_TYPE || dtype == tensor.dtype());
+                dtype = tensor.dtype();
+            }
+        }
+    }
+    invoke_launch(dtype, [&]<typename Config>() {
         GEMM_W4A4_Launch<Config>::gemm_w4a4(
             act,
             wgt,
@@ -61,7 +75,10 @@ void gemm_w4a4(
             out_linearattn,
             act_unsigned,
             lora_scales,
-            fuse_silu
+            fuse_silu,
+            fp4,
+            alpha,
+            wcscales
         );
     });
 }
@@ -72,10 +89,10 @@ void linearattn_vk_mul_q(Tensor q, Tensor vk) {
     });
 }
 
-void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) {
+void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
     invoke_launch(input.dtype(), [&]<typename Config>() {
         GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(
-            input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu
+            input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4
         );
     });
 }
src/kernels/zgemm/gemm_w4a4.cuh

(This diff is collapsed and not shown here.)
src/kernels/zgemm/gemm_w4a4_launch.cuh

@@ -12,6 +12,8 @@ class GEMM_W4A4_Launch {
     using packed_wgt_t = typename GEMM::packed_wgt_t;
     using packed_ascale_t = typename GEMM::packed_ascale_t;
     using packed_wscale_t = typename GEMM::packed_wscale_t;
+    using packed_amscale_t = typename GEMM::packed_amscale_t;
+    using packed_wmscale_t = typename GEMM::packed_wmscale_t;
     using packed_fpsum_t = typename GEMM::packed_fpsum_t;
     using half_t = typename GEMM::half_t;
@@ -38,9 +40,12 @@ public:
         Tensor out_linearattn, // linear [B, (M), N / 3]
         bool act_unsigned,
         std::vector<float> lora_scales,  // [R / 16]
-        bool fuse_silu
+        bool fuse_silu,
+        bool fp4,
+        float alpha,
+        Tensor wcscales  // packed ws [N]
     );
-    static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu);
+    static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4);
     static void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
     static void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh

@@ -30,7 +30,10 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
     Tensor out_linearattn, // linear [B, (M), N / 3]
     bool act_unsigned,
     std::vector<float> lora_scales,  // [R / 16]
-    bool fuse_silu
+    bool fuse_silu,
+    bool fp4,
+    float alpha,
+    Tensor wcscales  // packed ws [N]
 )
 {
     int M = act.numel() / act.shape[-1];
     int N = wgt.shape[0];
@@ -68,58 +71,111 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
         std::swap(grid.x, grid.y);
     }
 
-    dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
-        // test_sizeof<typename Epilogue::Arguments>();
-        // std::apply([](auto ...args) {
-        //     (test_sizeof<decltype(args)>(), ...);
-        // }, args);
-        using kernel = typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>;
-        auto func = invoke_kernel<kernel, const packed_act_t *, const packed_wgt_t *, const packed_ascale_t *, const packed_wscale_t *, int, int, int, typename Epilogue::Arguments, bool, bool>;
-        if (shmem >= 24 * 1024) {
-            checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
-        }
-        func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(act.data_ptr<packed_act_t>(), wgt.data_ptr<packed_wgt_t>(), ascales.data_ptr<packed_ascale_t>(), wscales.data_ptr<packed_wscale_t>(), M, N, K, args, swapBlockMN, false);
-        checkCUDA(cudaGetLastError());
-    });
+    dispatchBool(fp4, [&]<bool USE_FP4>() {
+        // test_sizeof<typename Epilogue::Arguments>();
+        // std::apply([](auto ...args) {
+        //     (test_sizeof<decltype(args)>(), ...);
+        // }, args);
+        // constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;
+        if constexpr (!USE_FP4) {
+            dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
+                auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>, const packed_act_t *, const packed_wgt_t *, const packed_ascale_t *, const packed_wscale_t *, int, int, int, typename Epilogue::Arguments, bool, bool>;
+                if (shmem >= 24 * 1024) {
+                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+                }
+                assert(alpha == 1.0f);
+                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(act.data_ptr<packed_act_t>(), wgt.data_ptr<packed_wgt_t>(), ascales.data_ptr<packed_ascale_t>(), wscales.data_ptr<packed_wscale_t>(), M, N, K, args, swapBlockMN, false);
+                checkCUDA(cudaGetLastError());
+            });
+            return;
+        }
+        if constexpr (USE_FP4) {
+            dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
+                assert(!act_unsigned);
+                auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>, const packed_act_t *, const packed_wgt_t *, const packed_amscale_t *, const packed_wmscale_t *, float, int, int, int, typename Epilogue::Arguments, bool, bool>;
+                if (shmem >= 24 * 1024) {
+                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+                }
+                assert(ascales.dtype() == Tensor::FP8_E4M3);
+                assert(wscales.dtype() == Tensor::FP8_E4M3);
+                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(act.data_ptr<packed_act_t>(), wgt.data_ptr<packed_wgt_t>(), ascales.data_ptr<packed_amscale_t>(), wscales.data_ptr<packed_wmscale_t>(), alpha, M, N, K, args, swapBlockMN, false);
+                checkCUDA(cudaGetLastError());
+            });
+            return;
+        }
+        // if constexpr (USE_FP4 && !FP4_AVAILABLE) {
+        //     throw std::runtime_error("FP4 kernel is not available");
+        // }
+    });
     };
 
     auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
-        if (!bias.valid()) {
-            return launch.template operator()<NextEpilogue>(nextArgs);
-        }
-        assert(bias.numel() == N);
-        // append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
-        // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
-        using Epilogue = typename GEMM::EpilogueCombination<typename GEMM::EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
-        return launch.template operator()<Epilogue>({
-            typename GEMM::EpilogueBias::Arguments{
-                .bias = bias.data_ptr<packed_wscale_t>(),
-            },
-            nextArgs,
-            {}
-        });
-    };
+        assert(!bias.valid() || bias.numel() == N);
+        assert(!wcscales.valid() || wcscales.numel() == N);
+        dispatchBool(bias.valid(), [&]<bool USE_BIAS>() {
+            dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() {
+                using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>;
+                // append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
+                // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
+                using Epilogue = typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
+                return launch.template operator()<Epilogue>({
+                    typename EpilogueBias::Arguments{
+                        .bias = USE_BIAS ? bias.data_ptr<packed_wscale_t>() : nullptr,
+                        .scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr,
+                    },
+                    nextArgs,
+                    {}
+                });
+            });
+        });
+    };
     // auto launch_bias = launch;
@@ -206,29 +262,32 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
         static constexpr float SHIFT_GELU = 0.171875f;
 
-        constexpr bool USE_UNSIGNED = true;
-        using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED>;
-        auto argsQuantize = typename EpilogueQuantize::Arguments{
-            .qout = qout.data_ptr<packed_act_t>(),
-            .oscales = oscales.data_ptr<packed_ascale_t>(),
-            .shift_value = SHIFT_GELU,
-            .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
-        };
-        // TODO: check if gelu is needed
-        if (out.valid()) {
-            launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
-                typename GEMM::EpilogueDefault::Arguments{
-                    .out = out.data_ptr<half_t>(),
-                    .actualM = actualM,
-                    .actualN = actualN,
-                },
-                argsQuantize
-            }, {});
-        } else {
-            launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
-        }
+        dispatchBool(fp4, [&]<bool USE_FP4>() {
+            constexpr bool USE_UNSIGNED = !USE_FP4;
+            using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
+            auto argsQuantize = typename EpilogueQuantize::Arguments{
+                .qout = qout.data_ptr<packed_act_t>(),
+                .oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
+                .shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
+                .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
+            };
+            // TODO: check if gelu is needed
+            if (out.valid()) {
+                launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
+                    typename GEMM::EpilogueDefault::Arguments{
+                        .out = out.data_ptr<half_t>(),
+                        .actualM = actualM,
+                        .actualN = actualN,
+                    },
+                    argsQuantize
+                }, {});
+            } else {
+                launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
+            }
+        });
     } else if (out_linearattn.valid()) {
         assert(out_vk.valid());
@@ -326,7 +385,7 @@ void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
 }
 
 template<typename Config>
-void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) {
+void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
     const int actualM = input.numel() / input.shape[-1];
     const int actualN = input.shape[-1];
@@ -338,8 +397,13 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
     assert(output.shape[-1] == N / 2);
 
     // assert(oscales.dtype() == Tensor::FP16);
-    assert(isTypeMatch<half_t>(oscales.dtype()));
-    assert(oscales.numel() == M * N / GEMM::WARP_K);
+    if (fp4) {
+        assert(oscales.dtype() == Tensor::FP8_E4M3);
+        assert(oscales.numel() == M * N / GEMM::WARP_K * 4);
+    } else {
+        assert(isTypeMatch<half_t>(oscales.dtype()));
+        assert(oscales.numel() == M * N / GEMM::WARP_K);
+    }
 
     const int rank = lora_down.shape[1];
@@ -354,30 +418,32 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
     dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
         dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
-            using Lora = typename GEMM::Lora<RANK>;
-            using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU>;
-            auto func = invoke_kernel<kernel, typename kernel::Arguments>;
-            checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
-            // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
-            func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(typename kernel::Arguments{
-                .input = input.data_ptr<half_t>(),
-                .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
-                .output = output.data_ptr<packed_act_t>(),
-                .oscales = oscales.data_ptr<packed_ascale_t>(),
-                .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
-                .lora_act = lora_act_out.data_ptr<float>(),
-                .M = M,
-                .N = N,
-                .actualM = actualM,
-                .actualN = actualN,
-            });
-            checkCUDA(cudaGetLastError());
+            dispatchBool(fp4, [&]<bool USE_FP4>() {
+                using Lora = typename GEMM::Lora<RANK>;
+                using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;
+                auto func = invoke_kernel<kernel, typename kernel::Arguments>;
+                checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));
+                // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));
+                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(typename kernel::Arguments{
+                    .input = input.data_ptr<half_t>(),
+                    .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
+                    .output = output.data_ptr<packed_act_t>(),
+                    .oscales = oscales.data_ptr<typename kernel::oscales_t>(),
+                    .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
+                    .lora_act = lora_act_out.data_ptr<float>(),
+                    .M = M,
+                    .N = N,
+                    .actualM = actualM,
+                    .actualN = actualN,
+                });
+                checkCUDA(cudaGetLastError());
+            });
         });
     });
 }
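The launch code repeatedly uses dispatchBool to turn a runtime flag (fp4, act_unsigned, alpha != 1.0f, ...) into a compile-time template parameter by instantiating the body for both values. A minimal sketch of that dispatch pattern with a C++20 templated lambda (this helper is a hypothetical stand-in, not the real nunchaku dispatchBool):

#include <cstdio>

// Calls f.template operator()<true>() or <false>() depending on the runtime value,
// so the body can use the flag in if constexpr and in template arguments.
template<typename F>
void dispatchBool(bool value, F &&f) {
    if (value) {
        f.template operator()<true>();
    } else {
        f.template operator()<false>();
    }
}

int main() {
    bool fp4 = true;  // runtime flag
    dispatchBool(fp4, [&]<bool USE_FP4>() {
        if constexpr (USE_FP4) {
            std::printf("instantiated FP4 path\n");
        } else {
            std::printf("instantiated INT4 path\n");
        }
    });
}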
src/kernels/zgemm/gemm_w8a8.cu

@@ -100,9 +100,9 @@ void gemm_w8a8(Tensor act, // [M, K]
         // append EpilgoueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
         // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
-        using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias, NextEpilogue, GEMM::EpilogueNop>;
+        using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias<true, false>, NextEpilogue, GEMM::EpilogueNop>;
         return launch.template operator()<Epilogue>({
-            GEMM::EpilogueBias::Arguments{
+            GEMM::EpilogueBias<true, false>::Arguments{
                 .bias = bias.data_ptr<GEMM::packed_wscale_t>(),
             },
             nextArgs,
src/kernels/zgemm/zgemm.h

@@ -27,11 +27,14 @@ void gemm_w4a4(
     Tensor out_linearattn, // linear [B, (M), N / 3]
     bool act_unsigned,
     std::vector<float> lora_scales,  // [R / 16]
-    bool fuse_silu
+    bool fuse_silu,
+    bool fp4,
+    float alpha,
+    Tensor wcscales
 );
 
 void linearattn_vk_mul_q(Tensor q, Tensor vk);
-void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false);
+void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false, bool fp4 = false);
 void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
 void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
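Because the new fp4 parameter is appended with a default of false, existing callers of quantize_w4a4_act_fuse_lora compile unchanged and keep the old behavior; only FP4-aware call sites pass the extra argument. A small sketch of that source-compatibility property (simplified hypothetical signature, not the real API):

#include <cstdio>

// Simplified stand-in for the declaration in zgemm.h: the trailing
// defaulted parameters keep old call sites valid.
void quantize_act(int n, bool fuse_glu = false, bool fp4 = false) {
    std::printf("n=%d fuse_glu=%d fp4=%d\n", n, fuse_glu, fp4);
}

int main() {
    quantize_act(64);              // pre-existing call site: fp4 defaults to false
    quantize_act(64, true);        // old two-argument form still works
    quantize_act(64, true, true);  // new FP4-aware call site
}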