Project: fengzch-das / nunchaku — Commits

Commit b1b44398, authored Feb 26, 2025 by Samuel Tesfai

    Fixing merges

Parents: 004e4e31, 4b9c2e03
Showing 15 changed files with 849 additions and 192 deletions (+849 −192)
src/SanaModel.cpp                             +20  −15
src/SanaModel.h                                +5   −4
src/Serialization.cpp                          +2   −0
src/Tensor.h                                   +4   −1
src/common.h                                  +10   −0
src/interop/torch.cpp                          +4   −2
src/kernels/awq/gemv_awq.cu                    +4   −2
src/kernels/zgemm/gemm_base.cuh              +102  −38
src/kernels/zgemm/gemm_utils.cuh              +93   −0
src/kernels/zgemm/gemm_w4a4.cu                +22   −5
src/kernels/zgemm/gemm_w4a4.cuh              +412  −28
src/kernels/zgemm/gemm_w4a4_launch.cuh         +7   −2
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh  +157  −91
src/kernels/zgemm/gemm_w8a8.cu                 +2   −2
src/kernels/zgemm/zgemm.h                      +5   −2
src/SanaModel.cpp — view file @ b1b44398 (+20 −15)

@@ -8,11 +8,11 @@
 using spdlog::fmt_lib::format;
 using namespace nunchaku;

-SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device) :
+SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     dim(dim),
     dim_pad(ceilDiv(dim, 128) * 128),
-    qkv_proj(dim, dim_pad * 3, bias, dtype, device),
+    qkv_proj(dim, dim_pad * 3, bias, use_fp4, dtype, device),
-    out_proj(dim_pad, dim, bias, dtype, device),
+    out_proj(dim_pad, dim, bias, use_fp4, dtype, device),
     pag_to_v(std::nullopt)
 {
     registerChildren

@@ -21,7 +21,7 @@ SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, Tensor::S
         ;
     if (pag) {
-        pag_to_v.emplace(dim, dim_pad, bias, dtype, device);
+        pag_to_v.emplace(dim, dim_pad, bias, use_fp4, dtype, device);
         registerChildren(pag_to_v.value(), "pag_to_v");
     }
 }

@@ -63,7 +63,11 @@ Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
         qkv_proj.wscales,
         {},
         qact.lora_act, qkv_proj.lora_up,
         {}, {}, {}, {}, {},
         qkv_proj.bias,
         {},
         vk, q,
-        qact.is_unsigned, qkv_proj.lora_scales, false);
+        qact.is_unsigned, qkv_proj.lora_scales, false,
+        qkv_proj.use_fp4,
+        *qkv_proj.wtscale.data_ptr<float>(),
+        qkv_proj.wcscales.numel() > 0 ? qkv_proj.wcscales : Tensor{});

     debug("vk", vk);
     debug("q", q);

@@ -121,11 +125,11 @@ Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
     return out;
 }

-MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device) :
+MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     num_heads(num_heads), head_dim(head_dim),
-    q_linear(num_heads * head_dim, num_heads * head_dim, true, dtype, device),
+    q_linear(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device),
     kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
-    out_proj(num_heads * head_dim, num_heads * head_dim, true, dtype, device)
+    out_proj(num_heads * head_dim, num_heads * head_dim, true, use_fp4, dtype, device)
 {
     registerChildren
         (q_linear, "q_linear")

@@ -173,11 +177,11 @@ Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens
     return out_proj.forward(attn_output);
 }

-SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device) :
+SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     in_features(in_features), hidden_features(hidden_features),
-    inverted_conv(in_features, hidden_features * 2, true, dtype, device),
+    inverted_conv(in_features, hidden_features * 2, true, use_fp4, dtype, device),
     depth_conv(hidden_features * 2, true, dtype, device),
-    point_conv(hidden_features, in_features, false, dtype, device)
+    point_conv(hidden_features, in_features, false, use_fp4, dtype, device)
 {
     registerChildren
         (inverted_conv, "inverted_conv")

@@ -200,11 +204,11 @@ Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
     return point_conv.forward_quant(qact);
 }

-SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device) :
+SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device) :
     hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
-    attn(hidden_size, false, pag, dtype, device),
+    attn(hidden_size, false, pag, use_fp4, dtype, device),
-    cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, dtype, device),
+    cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, use_fp4, dtype, device),
-    ff(hidden_size, intermediate_size, dtype, device),
+    ff(hidden_size, intermediate_size, use_fp4, dtype, device),
     norm1(hidden_size, 1e-6, false, dtype, device),
     norm2(hidden_size, 1e-6, false, dtype, device)
 {

@@ -313,6 +317,7 @@ SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device)
             ceilDiv(int(round(config.expand_ratio * inner_dim)), 64) * 64,
             config.num_cross_attention_heads,
             std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
+            config.use_fp4,
             dtype, device
         ));
         registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
src/SanaModel.h — view file @ b1b44398 (+5 −4)

@@ -7,7 +7,7 @@
 class SanaLinearAttention : public Module {
 public:
-    SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device);
+    SanaLinearAttention(int dim, bool bias, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device);

     Tensor forward(Tensor x, Tensor out = {});
     Tensor forward_pag(Tensor x, bool cfg);

@@ -25,7 +25,7 @@ private:
 class MultiHeadCrossAttention : public Module {
 public:
-    MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device);
+    MultiHeadCrossAttention(int num_heads, int head_dim, bool use_fp4, Tensor::ScalarType dtype, Device device);

     Tensor forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt);

@@ -41,7 +41,7 @@ private:
 class SanaGLUMBConv : public Module {
 public:
-    SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device);
+    SanaGLUMBConv(int in_features, int hidden_features, bool use_fp4, Tensor::ScalarType dtype, Device device);

     Tensor forward(Tensor x, int H, int W);

@@ -57,7 +57,7 @@ private:
 class SanaLinearTransformerBlock : public Module {
 public:
-    SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device);
+    SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, bool use_fp4, Tensor::ScalarType dtype, Device device);

     Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);

@@ -83,6 +83,7 @@ struct SanaConfig {
     int num_cross_attention_heads;
     double expand_ratio;
     std::vector<int> pag_layers;
+    bool use_fp4;
 };

 class SanaModel : public Module {
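Taken together with the header changes above, the FP4 switch is a single constructor argument that each block forwards to its quantized submodules. A minimal construction sketch follows; the dimensions are made up, and obtaining a Device value is application-specific and not part of this commit — only the signatures come from src/SanaModel.h.

// Hypothetical usage sketch only. Dims are illustrative; `device` must come from the
// surrounding application. Signatures match the declarations in src/SanaModel.h.
SanaConfig config{};
config.use_fp4 = true;               // new field added by this commit

Tensor::ScalarType dtype = Tensor::BF16;
Device device = /* target GPU device, application-specific */;

SanaLinearTransformerBlock block(
    /*hidden_size=*/1152,
    /*intermediate_size=*/2880,
    /*num_cross_attention_heads=*/16,
    /*pag=*/false,
    /*use_fp4=*/config.use_fp4,      // forwarded internally to attn, cross_attn, and ff
    dtype, device);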
src/Serialization.cpp — view file @ b1b44398 (+2 −0)

@@ -117,6 +117,8 @@ void SafeTensors::parseHeader() {
         {"I8",  Tensor::INT8},
         {"I32", Tensor::INT32},
         {"I64", Tensor::INT64},
+        {"F8_E4M3", Tensor::FP8_E4M3},
+        {"F8_E5M2", Tensor::FP8_E5M2},
     };
     auto check = [](bool cond, std::source_location location = std::source_location::current()) {
src/Tensor.h — view file @ b1b44398 (+4 −1)

@@ -218,7 +218,8 @@ public:
     enum ScalarType {
         INVALID_SCALAR_TYPE,
         INT8, INT16, INT32, INT64,
-        FP16, FP32, BF16
+        FP16, FP32, BF16,
+        FP8_E4M3, FP8_E5M2,
     };

     struct TensorOptions {

@@ -546,6 +547,8 @@ inline const std::map<Tensor::ScalarType, size_t> Tensor::scalarSize = {
     {FP16, 2},
     {FP32, 4},
     {BF16, 2},
+    {FP8_E4M3, 1},
+    {FP8_E5M2, 1},
 };

 struct TensorsProvider {
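For reference, the two hunks above are consistent with each other and with the Serialization.cpp change: the new enum members are the one-byte FP8 formats that the safetensors parser now accepts as "F8_E4M3" and "F8_E5M2". A small illustrative check (not from the commit):

#include <cassert>

void check_fp8_scalar_sizes() {
    // Both FP8 variants occupy one byte, same as INT8, per the scalarSize table above.
    assert(Tensor::scalarSize.at(Tensor::FP8_E4M3) == 1);
    assert(Tensor::scalarSize.at(Tensor::FP8_E5M2) == 1);
    assert(Tensor::scalarSize.at(Tensor::INT8) == 1);
}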
src/common.h — view file @ b1b44398 (+10 −0)

@@ -9,6 +9,7 @@
 #include <memory>
 #include <source_location>
 #include <vector>
+#include <list>
 #include <stack>
 #include <map>
 #include <unordered_map>

@@ -79,6 +80,15 @@ constexpr T ceilDiv(T a, T b) {
     return (a + b - 1) / b;
 }

+template<typename T>
+constexpr int log2Up(T value) {
+    if (value <= 0) return 0;
+    if (value == 1) return 0;
+    return log2Up((value + 1) / 2) + 1;
+}
+
 struct CUBLASWrapper {
     cublasHandle_t handle = nullptr;
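The new log2Up helper computes ceil(log2(value)) by recursing on the value halved with rounding up; because it is constexpr, its behaviour can be pinned down at compile time. A few illustrative checks (not part of the commit):

// Illustrative compile-time checks for log2Up.
static_assert(log2Up(1) == 0);
static_assert(log2Up(2) == 1);
static_assert(log2Up(8) == 3);
static_assert(log2Up(9) == 4);   // rounds up for non-powers of two
static_assert(log2Up(0) == 0);   // non-positive inputs are clamped to 0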
src/interop/torch.cpp — view file @ b1b44398 (+4 −2)

@@ -28,8 +28,9 @@ Tensor from_torch(at::Tensor input) {
         {at::ScalarType::Float,    Tensor::FP32},
         {at::ScalarType::Half,     Tensor::FP16},
         {at::ScalarType::BFloat16, Tensor::BF16},
         {at::ScalarType::Short,    Tensor::INT16},
+        {at::ScalarType::Float8_e4m3fn, Tensor::FP8_E4M3},
+        {at::ScalarType::Float8_e5m2,   Tensor::FP8_E5M2},
     };
     result.scalarType = mapType.at(input.scalar_type());

@@ -55,8 +56,9 @@ at::Tensor to_torch(Tensor input) {
         {Tensor::FP32,  at::ScalarType::Float},
         {Tensor::FP16,  at::ScalarType::Half},
         {Tensor::BF16,  at::ScalarType::BFloat16},
         {Tensor::INT16, at::ScalarType::Short},
+        {Tensor::FP8_E4M3, at::ScalarType::Float8_e4m3fn},
+        {Tensor::FP8_E5M2, at::ScalarType::Float8_e5m2},
     };
     c10::TensorOptions opts(mapType.at(input.scalar_type()));
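With both maps extended, FP8 tensors can round-trip between PyTorch and the internal Tensor type. A minimal sketch, assuming a PyTorch build that exposes the Float8 dtypes (the shape and device below are arbitrary):

// Sketch only: relies on from_torch/to_torch from this file and PyTorch's FP8 dtypes.
at::Tensor src = at::empty({16}, at::dtype(at::ScalarType::Float8_e4m3fn).device(at::kCUDA));
Tensor mid     = from_torch(src);   // now maps to Tensor::FP8_E4M3 instead of throwing in mapType.at()
at::Tensor dst = to_torch(mid);     // and back to Float8_e4m3fn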
src/kernels/awq/gemv_awq.cu — view file @ b1b44398 (+4 −2)

@@ -140,8 +140,10 @@ __global__ void gemv_kernel(
     for (int i = 0; i < Num; ++i)
         psum[i] = static_cast<accum_t>(0.f);

-    extern __shared__ uint8_t shmem[];
-    float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
+    // extern __shared__ uint8_t shmem[];
+    // float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
+    __shared__ float out_smem[BlockSize / WARP_SIZE * 2][Num * kInterleave];

     const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
     const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
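This hunk swaps a dynamically sized extern shared buffer for a statically sized one, so the kernel no longer depends on the launch-time shared-memory byte count for this buffer. A generic CUDA illustration of the two declaration styles (assumed names, not taken from this file):

// Generic illustration: dynamic shared memory is sized at launch, static shared memory
// is sized at compile time from template/constexpr parameters.
__global__ void dynamic_smem_kernel() {
    extern __shared__ float buf_dyn[];   // byte size supplied at launch: kernel<<<grid, block, bytes>>>()
    buf_dyn[threadIdx.x] = 0.f;
}

template<int kRows, int kCols>
__global__ void static_smem_kernel() {
    __shared__ float buf[kRows][kCols];  // size fixed at compile time; no launch-time argument needed
    buf[0][threadIdx.x % kCols] = 0.f;
}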
src/kernels/zgemm/gemm_base.cuh — view file @ b1b44398 (+102 −38)

@@ -319,10 +319,10 @@ public:
     int warpId = threadIdx.x / WARP_SIZE;
 #pragma unroll
     for (int i = 0; i < WARP_M_TILES; i++) {
-        if (pred) {
-            // out[i] = load(&act[((warpId * WARP_M_TILES + i) * K / WARP_K + k) * WARP_SIZE + laneId]);
-            out[i] = load(&act[((k * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId]);
-        }
+        // if (pred) {
+        // out[i] = load(&act[((warpId * WARP_M_TILES + i) * K / WARP_K + k) * WARP_SIZE + laneId]);
+        out[i] = load_pred(&act[((k * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId], pred);
+        // }
     }

@@ -336,12 +336,12 @@ public:
     // int offset = K / WARP_K * WARP_SIZE;
 #pragma unroll
     for (int i = 0; i < WARP_N_TILES; i++) {
-        if (pred) {
-            // out[i] = load(&wgt[(i * K / WARP_K + k) * WARP_SIZE + laneId]);
-            // out[i] = load(&wgt[(i + k * WARP_N_TILES) * WARP_SIZE + laneId]);
-            out[i] = load(&ptr[i * WARP_SIZE]);
-            // ptr += offset;
-        }
+        // if (pred) {
+        // out[i] = load(&wgt[(i * K / WARP_K + k) * WARP_SIZE + laneId]);
+        // out[i] = load(&wgt[(i + k * WARP_N_TILES) * WARP_SIZE + laneId]);
+        out[i] = load_pred(&ptr[i * WARP_SIZE], pred);
+        // ptr += offset;
+        // }
     }

@@ -352,11 +352,11 @@ public:
     int warpId = threadIdx.x / WARP_SIZE;
 #pragma unroll
     for (int i = 0; i < ASCALES_NUM_PACKS; i++) {
-        if (pred && laneId < ASCALES_VALID_LANES) {
-            // out[i] = ascales[(group * M / WARP_M + warpId) * ASCALES_VALID_LANES * ASCALES_NUM_PACKS + i * ASCALES_VALID_LANES + laneId];
-            out[i] = ascales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES + i * ASCALES_VALID_LANES + laneId];
-        }
+        // if (pred && laneId < ASCALES_VALID_LANES) {
+        // out[i] = ascales[(group * M / WARP_M + warpId) * ASCALES_VALID_LANES * ASCALES_NUM_PACKS + i * ASCALES_VALID_LANES + laneId];
+        out[i] = load_pred(&ascales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES + i * ASCALES_VALID_LANES + laneId], pred && laneId < ASCALES_VALID_LANES);
+        // }
     }

@@ -373,13 +373,13 @@ public:
 #pragma unroll
     for (int i = 0; i < WSCALES_NUM_PACKS; i++) {
-        if (pred && laneId < WSCALES_VALID_LANES) {
-            // out[i] = wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId];
-            // out[i] = load(&wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId]);
-            out[i] = load(&wscales[(group * WSCALES_NUM_PACKS + i) * WSCALES_VALID_LANES + laneId]);
-            // out[i] = load(&ptr[i * WSCALES_VALID_LANES]);
-        }
+        // if (pred && laneId < WSCALES_VALID_LANES) {
+        // out[i] = wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId];
+        // out[i] = load(&wscales[group * N / WARP_N * WSCALES_VALID_LANES * WSCALES_NUM_PACKS + i * WSCALES_VALID_LANES + laneId]);
+        out[i] = load_pred(&wscales[(group * WSCALES_NUM_PACKS + i) * WSCALES_VALID_LANES + laneId], pred && laneId < WSCALES_VALID_LANES);
+        // out[i] = load(&ptr[i * WSCALES_VALID_LANES]);
+        // }
     }

@@ -400,7 +400,7 @@ public:
         return __shfl_sync(~0, block[packIdx].data[elementIdx], srcLane);
     }

-    template<typename F>
+    template<bool FAST_I2F = false, typename F>
     __device__ __forceinline__
     static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, fpsum_warp &fpsum) {
         const int laneId = threadIdx.x % WARP_SIZE;

@@ -429,12 +429,31 @@ public:
             //     printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
             // }
-            fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
-            fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
-            fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
-            fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
+            auto scale_fma_normal = [&]() ALWAYSINLINE {
+                fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[0]), __int2float_rn(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
+                fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[2]), __int2float_rn(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
+                fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[4]), __int2float_rn(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
+                fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(__int2float_rn(psum.data[6]), __int2float_rn(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
+            };
+            // should be faster on sm_80
+            auto scale_fma_fast = [&]() ALWAYSINLINE {
+                fsum.data[0] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[0]), int2float_fast(psum.data[1]))), __hmul2(asx[i], ws1), fsum.data[0]);
+                fsum.data[1] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[2]), int2float_fast(psum.data[3]))), __hmul2(asy[i], ws1), fsum.data[1]);
+                fsum.data[2] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[4]), int2float_fast(psum.data[5]))), __hmul2(asx[i], ws2), fsum.data[2]);
+                fsum.data[3] = __hfma2(float22half2<half2_t>(make_float2(int2float_fast(psum.data[6]), int2float_fast(psum.data[7]))), __hmul2(asy[i], ws2), fsum.data[3]);
+            };
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ <= 800
+            if constexpr (FAST_I2F) {
+                scale_fma_fast();
+            } else {
+                scale_fma_normal();
+            }
+#else
+            scale_fma_normal();
+#endif
             // if (threadIdx.x == 3 && j == 1 && i == 0) {
             //     printf("before ws2 = %f %f fsum.data[%d] = %f %f\n", (float)ws2.x, (float)ws2.y, target, (float)fsum.data[target].x, (float)fsum.data[target].y);
             // }

@@ -575,9 +594,9 @@ public:
             (plugins(i * INSN_M + row, pack), ...);
             bool pred = i * INSN_M + row < maxRows && laneId * PACK_SIZE < maxCols;
-            if (pred) {
-                store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack);
-            }
+            // if (pred) {
+            store_pred(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack, pred);
+            // }
         }
         __syncwarp();

@@ -602,9 +621,9 @@ public:
             (plugins(i * INSN_M + 8 + row, pack), ...);
             bool pred = i * INSN_M + 8 + row < maxRows && laneId * PACK_SIZE < maxCols;
-            if (pred) {
-                store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]), pack);
-            }
+            // if (pred) {
+            store_pred(reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]), pack, pred);
+            // }
         }
         __syncwarp();
     }

@@ -680,33 +699,61 @@ public:
         }
     };

+    template<bool USE_BIAS = true, bool USE_SCALE = false>
     struct EpilogueBias {
         struct Arguments {
             const packed_wscale_t *bias;   // [N / BLOCK_N, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t
+            const packed_wscale_t *scale;
         };

         __device__ __forceinline__
-        void apply_bias(fpsum_warp &fpsum, int M, int N, int K, const packed_wscale_t *bias) {
+        void apply_bias(fpsum_warp &fpsum, int M, int N, int K, const packed_wscale_t *bias, const packed_wscale_t *scale) {
             const int laneId = threadIdx.x % WARP_SIZE;
             // if (laneId == 0) {
             //     printf("block.x=%d block.y=%d warpId=%d bias=%p\n", blockIdx.x, blockIdx.y, threadIdx.x / WARP_SIZE, bias);
             // }
-            wscale_warp b;
-            load_wscale(bias, 0, N, b, true);
+            wscale_warp b, s;
+            if constexpr (USE_BIAS) {
+                load_wscale(bias, 0, N, b, true);
+            }
+            if constexpr (USE_SCALE) {
+                load_wscale(scale, 0, N, s, true);
+            }
             for (int j = 0; j < WARP_N_TILES; j++) {
-                half2_t b1 = broadcast_wscale(b, j * 4, laneId);
-                half2_t b2 = broadcast_wscale(b, j * 4 + 2, laneId);
+                half2_t b1, b2;
+                half2_t s1, s2;
+                if constexpr (USE_BIAS) {
+                    b1 = broadcast_wscale(b, j * 4, laneId);
+                    b2 = broadcast_wscale(b, j * 4 + 2, laneId);
+                }
+                if constexpr (USE_SCALE) {
+                    s1 = broadcast_wscale(s, j * 4, laneId);
+                    s2 = broadcast_wscale(s, j * 4 + 2, laneId);
+                }
                 for (int i = 0; i < WARP_M_TILES; i++) {
                     auto &fsum = fpsum[i * WARP_N_TILES + j];
-                    fsum.data[0] = __hadd2(fsum.data[0], b1);
-                    fsum.data[1] = __hadd2(fsum.data[1], b1);
-                    fsum.data[2] = __hadd2(fsum.data[2], b2);
-                    fsum.data[3] = __hadd2(fsum.data[3], b2);
+                    if constexpr (USE_SCALE && USE_BIAS) {
+                        fsum.data[0] = __hfma2(fsum.data[0], s1, b1);
+                        fsum.data[1] = __hfma2(fsum.data[1], s1, b1);
+                        fsum.data[2] = __hfma2(fsum.data[2], s2, b2);
+                        fsum.data[3] = __hfma2(fsum.data[3], s2, b2);
+                    } else if constexpr (USE_SCALE) {
+                        fsum.data[0] = __hmul2(fsum.data[0], s1);
+                        fsum.data[1] = __hmul2(fsum.data[1], s1);
+                        fsum.data[2] = __hmul2(fsum.data[2], s2);
+                        fsum.data[3] = __hmul2(fsum.data[3], s2);
+                    } else if constexpr (USE_BIAS) {
+                        fsum.data[0] = __hadd2(fsum.data[0], b1);
+                        fsum.data[1] = __hadd2(fsum.data[1], b1);
+                        fsum.data[2] = __hadd2(fsum.data[2], b2);
+                        fsum.data[3] = __hadd2(fsum.data[3], b2);
+                    }
                 }
             }
         }

@@ -714,10 +761,13 @@ public:
         __device__ __forceinline__
         void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
             const int bn = binfo.bn;
-            apply_bias(
-                fpsum, M, N, K,
-                args.bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
-            );
+            if constexpr (USE_BIAS || USE_SCALE) {
+                apply_bias(
+                    fpsum, M, N, K,
+                    args.bias + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES,
+                    args.scale + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
+                );
+            }
         }
     };

@@ -797,7 +847,21 @@ public:
     using typename Base::unpack_fpsum; \
     using typename Base::EpilogueDefault; \
     using typename Base::EpilogueNop; \
-    using typename Base::EpilogueBias;
+    template<bool USE_BIAS, bool USE_SCALE> \
+    using EpilogueBias = typename Base::EpilogueBias<USE_BIAS, USE_SCALE>; \
+    using Base::mma_f16xf16_f32; \
+    using Base::packed_fp32_to_fp16; \
+    using Base::packed_fp16_to_fp32; \
+    using Base::load_act; \
+    using Base::load_wgt; \
+    using Base::load_ascale; \
+    using Base::load_wscale; \
+    using Base::broadcast_wscale; \
+    using Base::broadcast_ascale; \
+    using Base::apply_scales; \
+    using Base::pack_ascales; \
+    using Base::pack_wscales; \
+    using Base::apply_act;

 template<typename kernel, typename ...T>
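Two themes run through this file: bounds checks around global loads and stores move from `if (pred)` branches into predicated PTX helpers (`load_pred`/`store_pred`, defined in gemm_utils.cuh below), and EpilogueBias gains an optional per-output-channel scale. A scalar sketch of the four compile-time variants of the new epilogue, illustrative only — the real code operates on half2 pairs:

// Scalar illustration (not from the commit) of EpilogueBias<USE_BIAS, USE_SCALE>:
// per output channel n it applies an optional scale and an optional bias.
inline float epilogue_bias_scale(float acc, float scale_n, float bias_n,
                                 bool use_scale, bool use_bias) {
    if (use_scale && use_bias) return acc * scale_n + bias_n;  // __hfma2 path
    if (use_scale)             return acc * scale_n;           // __hmul2 path
    if (use_bias)              return acc + bias_n;            // __hadd2 path
    return acc;                                                // epilogue body is skipped entirely
}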
src/kernels/zgemm/gemm_utils.cuh — view file @ b1b44398 (+93 −0)

@@ -43,6 +43,41 @@ static T load(const T *addr) {
     return *addr;
 }

+template<typename T>
+__device__ __forceinline__
+static T load_pred(const T *addr, bool pred) {
+    if constexpr (sizeof(T) == 4) {
+        uint32_t data;
+        asm volatile ("{ .reg .pred loadpred; setp.ne.b32 loadpred, %2, 0;"
+                      "@loadpred ld.global.nc.b32 %0, [%1];"
+                      "}" : "=r"(data) : "l"(addr), "r"((int)pred));
+        return *reinterpret_cast<T *>(&data);
+    }
+    if constexpr (sizeof(T) == 8) {
+        uint2 data;
+        asm volatile ("{ .reg .pred loadpred; setp.ne.b32 loadpred, %3, 0;"
+                      "@loadpred ld.global.nc.v2.b32 {%0, %1}, [%2];"
+                      "}" : "=r"(data.x), "=r"(data.y) : "l"(addr), "r"((int)pred));
+        return *reinterpret_cast<T *>(&data);
+    }
+    if constexpr (sizeof(T) == 16) {
+        uint4 data;
+        asm volatile ("{ .reg .pred loadpred; setp.ne.b32 loadpred, %5, 0;"
+                      "@loadpred ld.global.nc.v4.b32 {%0, %1, %2, %3}, [%4];"
+                      "}" : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) : "l"(addr), "r"((int)pred));
+        return *reinterpret_cast<T *>(&data);
+    }
+    T result;
+    if (pred) {
+        result = *addr;
+    }
+    return result;
+}
+
 template<bool shmem = false, typename T>
 __device__ __forceinline__
 static void store(T *addr, T val) {

@@ -76,6 +111,39 @@ static void store(T *addr, T val) {
     *addr = val;
 }

+template<typename T>
+__device__ __forceinline__
+static void store_pred(T *addr, T val, bool pred) {
+    if constexpr (sizeof(T) == 4) {
+        uint32_t data = *reinterpret_cast<uint32_t *>(&val);
+        asm volatile ("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+                      "@storepred st.global.cg.b32 [%1], %2;"
+                      "}" :: "r"((int)pred), "l"(addr), "r"(data));
+        return;
+    }
+    if constexpr (sizeof(T) == 8) {
+        uint2 data = *reinterpret_cast<uint2 *>(&val);
+        asm volatile ("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+                      "@storepred st.global.cg.v2.b32 [%1], {%2, %3};"
+                      "}" :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y));
+        return;
+    }
+    if constexpr (sizeof(T) == 16) {
+        uint4 data = *reinterpret_cast<uint4 *>(&val);
+        asm volatile ("{ .reg .pred storepred; setp.ne.b32 storepred, %0, 0;"
+                      "@storepred st.global.cg.v4.b32 [%1], {%2, %3, %4, %5};"
+                      "}" :: "r"((int)pred), "l"(addr), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w));
+        return;
+    }
+    if (pred) {
+        *addr = val;
+    }
+}
+
 __device__ __forceinline__
 static float2 half22float2(half2 val) {
     return __half22float2(val);

@@ -159,6 +227,21 @@ uint32_t quantize_float2<8, false>(float2 value) {
     return result;
 }

+__device__ __forceinline__
+uint32_t quantize_float2_fp4(float2 value) {
+    uint32_t result;
+    asm volatile ("{ .reg .b8 tmp; cvt.rn.satfinite.e2m1x2.f32 tmp, %1, %2; cvt.u32.u8 %0, tmp; }"
+                  : "=r"(result) : "f"(value.y), "f"(value.x));
+    return result;
+}
+
+__device__ __forceinline__
+uint32_t quantize_float4_fp8(float4 value) {
+    uint16_t lo, hi;
+    asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(lo) : "f"(value.y), "f"(value.x));
+    asm volatile ("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;" : "=h"(hi) : "f"(value.w), "f"(value.z));
+    return uint32_t(lo) | (uint32_t(hi) << 16);
+}
+
 __device__ __forceinline__
 static float cuda_tanhf(float x) {
     float result;

@@ -271,4 +354,14 @@ static void unrolled_loop(F &&lambda) {
     call(std::make_integer_sequence<int, cnt>());
 }

+// int2float is slow on sm_80 and before
+// val in [-4194304, 4194303]
+__device__ __forceinline__
+static float int2float_fast(int val) {
+    float fval;
+    // fval = (val & 0x7FFFFF) ^ 0x4B400000
+    asm volatile ("lop3.b32 %0, %1, %2, %3, %4;" : "=f"(fval) : "r"(val), "n"(0x7FFFFF), "n"(0x4B400000), "n"((0xF0 & 0xCC) ^ 0xAA));
+    return fval - 12582912.0f;
+}
+
 };  // namespace nunchaku::kernels
\ No newline at end of file
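The new int2float_fast uses the classic magic-constant trick: for values in [-2^22, 2^22), placing the low 23 bits of the two's-complement integer into the mantissa of 1.5 * 2^23 (bit pattern 0x4B400000) and then subtracting 12582912.0f reproduces float(val) without an I2F instruction, which the comment notes is slow on sm_80 and earlier. A host-side reference of the same bit manipulation (my own sketch, equivalent to the single lop3.b32 above):

// Host-side reference (not from the commit) for the int2float_fast bit trick.
#include <cstdint>
#include <cstring>
#include <cassert>

static float int2float_fast_ref(int val) {
    // Same result as the lop3.b32 with immLut (a & b) ^ c used on the GPU.
    uint32_t bits = (uint32_t(val) & 0x7FFFFF) ^ 0x4B400000;
    float fval;
    std::memcpy(&fval, &bits, sizeof(fval));
    return fval - 12582912.0f;   // 12582912.0f == 1.5 * 2^23 == bit pattern 0x4B400000
}

int main() {
    for (int v : {-4194304, -1, 0, 1, 123456, 4194303}) {
        assert(int2float_fast_ref(v) == float(v));
    }
    return 0;
}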
src/kernels/zgemm/gemm_w4a4.cu — view file @ b1b44398 (+22 −5)

@@ -36,9 +36,23 @@ void gemm_w4a4(
     Tensor out_linearattn,  // linear  [B, (M), N / 3]
     bool act_unsigned,
     std::vector<float> lora_scales,  // [R / 16]
-    bool fuse_silu
+    bool fuse_silu,
+    bool fp4,
+    float alpha,
+    Tensor wcscales
 ) {
-    invoke_launch(ascales.dtype(), [&]<typename Config>() {
+    Tensor::ScalarType dtype = Tensor::INVALID_SCALAR_TYPE;
+    if (!fp4) {
+        dtype = ascales.dtype();
+    } else {
+        for (auto tensor : {out, bias, lora_up, lora_down, poolout, wcscales}) {
+            if (tensor.valid()) {
+                assert(dtype == Tensor::INVALID_SCALAR_TYPE || dtype == tensor.dtype());
+                dtype = tensor.dtype();
+            }
+        }
+    }
+    invoke_launch(dtype, [&]<typename Config>() {
         GEMM_W4A4_Launch<Config>::gemm_w4a4(
             act,
             wgt,

@@ -61,7 +75,10 @@ void gemm_w4a4(
             out_linearattn,
             act_unsigned,
             lora_scales,
-            fuse_silu
+            fuse_silu,
+            fp4,
+            alpha,
+            wcscales
         );
     });
 }

@@ -72,10 +89,10 @@ void linearattn_vk_mul_q(Tensor q, Tensor vk) {
     });
 }

-void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) {
+void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
     invoke_launch(input.dtype(), [&]<typename Config>() {
         GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(
-            input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu
+            input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4
        );
    });
 }
src/kernels/zgemm/gemm_w4a4.cuh — view file @ b1b44398 (+412 −28)

 #pragma once
 #include "gemm_base.cuh"
+// #include "gemm_w4a4_block.cuh"

 namespace nunchaku::kernels {

@@ -19,6 +20,369 @@ class GEMM_W4A4<GEMMConfig_W4A4_FP16> : public GEMMBase<GEMMConfig_W4A4_FP16> {
 public:
     IMPORT_GEMM_BASE(Config);

+public:
+    // micro-scales for FP4 MMA
+    // each uint32_t is a 4*32 matrix of scales (for MMA of 64*32)
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1200
+    static constexpr bool FP4_AVAILABLE = true;
+#else
+    static constexpr bool FP4_AVAILABLE = false;
+#endif
+
+    __device__ __forceinline__
+    static void trap_no_fp4() {
+        if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) {
+            printf("FP4 is not available on this device\n");
+        }
+        __syncthreads();
+        __nanosleep(1000000);
+        __trap();
+    }
+
+    static_assert(WARP_N % 32 == 0);
+    static_assert(WARP_M % 32 == 0);
+    static constexpr int WMSCALES_PACK_SIZE   = clamp(WARP_N / 32, 1, 4);
+    static constexpr int WMSCALES_NUM_PACKS   = ceilDiv(WARP_N / 32, WMSCALES_PACK_SIZE);
+    static constexpr int WMSCALES_VALID_LANES = WARP_SIZE;
+
+    static constexpr int AMSCALES_PACK_SIZE   = clamp(WARP_M / 32, 1, 4);
+    static constexpr int AMSCALES_NUM_PACKS   = ceilDiv(WARP_M / 32, AMSCALES_PACK_SIZE);
+    static constexpr int AMSCALES_VALID_LANES = WARP_SIZE;
+
+    struct packed_wmscale_t {
+        uint32_t data[WMSCALES_PACK_SIZE];
+    };
+    struct packed_amscale_t {
+        uint32_t data[AMSCALES_PACK_SIZE];
+    };
+    using amscale_warp = std::array<packed_amscale_t, AMSCALES_NUM_PACKS>;
+    using wmscale_warp = std::array<packed_wmscale_t, WMSCALES_NUM_PACKS>;
+
+    // amscales: [M / BLOCK_M, K / group size, NUM_WARPS, AMSCALES_NUM_PACKS, WARP_SIZE] of packed_amscale_t
+    __device__ __forceinline__
+    static void load_amscale(const packed_amscale_t *ptr, int group, amscale_warp &out, bool pred) {
+        const int laneId = threadIdx.x % WARP_SIZE;
+        const int warpId = threadIdx.x / WARP_SIZE;
+#pragma unroll
+        for (int i = 0; i < AMSCALES_NUM_PACKS; i++) {
+            out[i] = load_pred(&ptr[(group * NUM_WARPS + warpId) * AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES + i * AMSCALES_VALID_LANES + laneId], pred);
+        }
+    }
+    // wmscales: [N / BLOCK_N, 1, K / group size, WMSCALES_NUM_PACKS, WMSCALES_VALID_LANES] of packed_wmscale_t
+    __device__ __forceinline__
+    static void load_wmscale(const packed_wmscale_t *ptr, int group, wmscale_warp &out, bool pred) {
+        const int laneId = threadIdx.x % WARP_SIZE;
+#pragma unroll
+        for (int i = 0; i < WMSCALES_NUM_PACKS; i++) {
+            out[i] = load_pred(&ptr[(group * WMSCALES_NUM_PACKS + i) * WMSCALES_VALID_LANES + laneId], pred);
+        }
+    }
+
+    __device__ __forceinline__
+    static void quantize_w4a4_fp4_from_fpsum_warp(const packed_fpsum_t (&fpsum)[INSN_K / INSN_N], packed_act_t &output, uint32_t &output_scale, int ida) {
+        constexpr int NUM_GROUPS = 4;
+        static_assert(NUM_GROUPS == INSN_K / INSN_N);
+        constexpr float QVALUE_MAX = 6.0f;
+        constexpr float RECPI_QVALUE_MAX = 1 / QVALUE_MAX;
+        constexpr float MSCALE_MAX = 448.0f;
+        const int laneId = threadIdx.x % WARP_SIZE;
+
+        // 0 for row 0-7; 1 for row 8-15
+        // each half2_t represents a 8*8 matrix
+        half2_t input[2][INSN_K / INSN_N * 2];
+#pragma unroll
+        for (int i = 0; i < INSN_K / INSN_N; i++) {
+            input[0][i * 2 + 0] = fpsum[i].data[0];
+            input[0][i * 2 + 1] = fpsum[i].data[2];
+            input[1][i * 2 + 0] = fpsum[i].data[1];
+            input[1][i * 2 + 1] = fpsum[i].data[3];
+        }
+
+        auto maxabs = [](half2_t val) ALWAYSINLINE {
+            val = __habs2(val);
+            return __hmax(val.x, val.y);
+        };
+
+        // each half_t represents maxvalue in a 8*16 matrix
+        half_t maxvalue[2][NUM_GROUPS];
+#pragma unroll
+        for (int i = 0; i < NUM_GROUPS; i++) {
+            maxvalue[0][i] = __hmax(maxabs(input[0][i * 2]), maxabs(input[0][i * 2 + 1]));
+            maxvalue[1][i] = __hmax(maxabs(input[1][i * 2]), maxabs(input[1][i * 2 + 1]));
+        }
+#pragma unroll
+        for (int mask = 2; mask > 0; mask /= 2) {
+#pragma unroll
+            for (int i = 0; i < NUM_GROUPS; i++) {
+                maxvalue[0][i] = __hmax(maxvalue[0][i], __shfl_xor_sync(~0, maxvalue[0][i], mask));
+                maxvalue[1][i] = __hmax(maxvalue[1][i], __shfl_xor_sync(~0, maxvalue[1][i], mask));
+            }
+        }
+        // lane 0,1,2,3 / 4,5,6,7 / ... should have identical maxvalue now
+
+        float scale[2][NUM_GROUPS];
+        float rscale[2][NUM_GROUPS];
+#pragma unroll
+        for (int i = 0; i < NUM_GROUPS; i++) {
+            scale[0][i] = fminf(float(maxvalue[0][i]) * RECPI_QVALUE_MAX, MSCALE_MAX);
+            scale[1][i] = fminf(float(maxvalue[1][i]) * RECPI_QVALUE_MAX, MSCALE_MAX);
+            // TODO: check whether (1 / scale) or (1 / fp8scale) is better
+            rscale[0][i] = cuda_frcp(scale[0][i]);
+            rscale[1][i] = cuda_frcp(scale[1][i]);
+        }
+        uint32_t fp8scale[2];
+        fp8scale[0] = quantize_float4_fp8(make_float4(scale[0][0], scale[0][1], scale[0][2], scale[0][3]));
+        fp8scale[1] = quantize_float4_fp8(make_float4(scale[1][0], scale[1][1], scale[1][2], scale[1][3]));
+
+        /**
+         * output_scale pack format: (ida=0)
+         * lane 0 => row 0 if ida==0
+         * lane 1 => row 8 if ida==0
+         * lane 2 => row 0 if ida==1
+         * lane 3 => row 8 if ida==1
+         * ...
+         * lane i => quad (i/4) => row (i/4+8*(i%2)) if (i%4/2==ida) => srclane i, index i%2
+         */
+        if (laneId % 4 / 2 == ida) {
+            output_scale = (laneId % 2 == 0) ? fp8scale[0] : fp8scale[1];
+        }
+
+        uint32_t qpacks[2][INSN_K / INSN_M * 2];
+#pragma unroll
+        for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
+#pragma unroll
+            for (int j = 0; j < 2; j++) {
+                float2 fval = half22float2(input[j][i]) * make_float2(rscale[j][i / 2], rscale[j][i / 2]);
+                qpacks[j][i] = quantize_float2_fp4(fval) << (laneId % 4 * 8);
+            }
+        }
+#pragma unroll
+        for (int mask = 1; mask <= 2; mask *= 2) {
+#pragma unroll
+            for (int i = 0; i < INSN_K / INSN_M * 2; i++) {
+#pragma unroll
+                for (int j = 0; j < 2; j++) {
+                    qpacks[j][i] |= __shfl_xor_sync(~0, qpacks[j][i], mask);
+                }
+            }
+        }
+        // lane 0,1,2,3 / 4,5,6,7 / ... should have identical qpacks now
+#pragma unroll
+        for (int i = 0; i < 4; i++) {
+            if (laneId % 4 == i) {
+                output.x = qpacks[0][0 + i];
+                output.y = qpacks[1][0 + i];
+                output.z = qpacks[0][4 + i];
+                output.w = qpacks[1][4 + i];
+            }
+        }
+    }
+
+    // m16n16k64 MMA
+    // ida, idb in {0, 1}
+    __device__ __forceinline__
+    static packed_f32psum_t mma_fp4(packed_act_t act, packed_wgt_t wgt, packed_f32psum_t psum, uint32_t amscale, uint32_t wmscale, int ida, int idb) {
+        packed_f32psum_t out;
+        asm volatile (
+            "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
+            "{%0, %1, %2, %3}, "
+            "{%4, %5, %6, %7}, "
+            "{%8, %9}, "
+            "{%10, %11, %12, %13}, "
+            "{%14}, {%15, %16}, "
+            "{%17}, {%18, %19};"
+            : "=f"(out.data[0]), "=f"(out.data[1]), "=f"(out.data[2]), "=f"(out.data[3])
+            : "r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
+              "r"(wgt.x), "r"(wgt.y),
+              "f"(psum.data[0]), "f"(psum.data[1]), "f"(psum.data[2]), "f"(psum.data[3]),
+              "r"(amscale), "n"(0), "h"((short)ida),
+              "r"(wmscale), "n"(0), "h"((short)(idb * 2))
+        );
+        asm volatile (
+            "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3 "
+            "{%0, %1, %2, %3}, "
+            "{%4, %5, %6, %7}, "
+            "{%8, %9}, "
+            "{%10, %11, %12, %13}, "
+            "{%14}, {%15, %16}, "
+            "{%17}, {%18, %19};"
+            : "=f"(out.data[4]), "=f"(out.data[5]), "=f"(out.data[6]), "=f"(out.data[7])
+            : "r"(act.x), "r"(act.y), "r"(act.z), "r"(act.w),
+              "r"(wgt.z), "r"(wgt.w),
+              "f"(psum.data[4]), "f"(psum.data[5]), "f"(psum.data[6]), "f"(psum.data[7]),
+              "r"(amscale), "n"(0), "h"((short)ida),
+              "r"(wmscale), "n"(0), "h"((short)(idb * 2 + 1))
+        );
+        return out;
+    }
+
+    __device__ __forceinline__
+    static void compute_fp4(act_warp A, wgt_warp W, amscale_warp amscale, wmscale_warp wmscale, f32psum_warp &psum) {
+        const int laneId = threadIdx.x % WARP_SIZE;
+        const int warpId = threadIdx.x / WARP_SIZE;
+#pragma unroll
+        for (int j = 0; j < WARP_N_TILES; j++) {
+#pragma unroll
+            for (int i = 0; i < WARP_M_TILES; i++) {
+                psum[i * WARP_N_TILES + j] = mma_fp4(
+                    A[i], W[j], psum[i * WARP_N_TILES + j],
+                    amscale[i / 2 / AMSCALES_PACK_SIZE].data[i / 2 % AMSCALES_PACK_SIZE],
+                    wmscale[j / 2 / WMSCALES_PACK_SIZE].data[j / 2 % WMSCALES_PACK_SIZE],
+                    i % 2, j % 2);
+            }
+        }
+    }
+
+    template<typename Epilogue, bool USE_ALPHA>
+    __device__ __forceinline__
+    static void gemm_w4a4_fp4_block(
+        const BlockInfo binfo,
+        const packed_act_t *act,
+        const packed_wgt_t *wgt,
+        const packed_amscale_t *ascales,
+        const packed_wmscale_t *wscales,
+        float alpha,  // per-tensor scale of weight
+        int M, int N, int K,
+        Epilogue::Arguments epilogueArgs,
+        bool alwaysfalse)
+    {
+        constexpr int NUM_STAGES = 2;
+        const int laneId = threadIdx.x % WARP_SIZE;
+        const int warpId = threadIdx.x / WARP_SIZE;
+
+        act_warp A[NUM_STAGES];            // 8 * 2
+        wgt_warp W[NUM_STAGES];            // 32 * 2
+        amscale_warp amscale[NUM_STAGES];  // 1 * 2
+        wmscale_warp wmscale[NUM_STAGES];  // 4 * 2
+        f32psum_warp fpsum;                // 128
+
+        for (int k = 0; k < NUM_STAGES - 1; k++) {
+            load_act(act, k, K, A[k], true);
+            load_wgt(wgt, k, K, W[k], true);
+            load_amscale(ascales, k, amscale[k], true);
+            load_wmscale(wscales, k, wmscale[k], true);
+        }
+
+#pragma unroll
+        for (auto &pack : fpsum) {
+#pragma unroll
+            for (int i = 0; i < 8; i++) {
+                pack.data[i] = 0;
+            }
+        }
+
+        int dummy = 0;
+
+        for (int k1 = 0; k1 < K / WARP_K; k1 += NUM_STAGES) {
+#pragma unroll
+            for (int k2 = 0; k2 < NUM_STAGES; k2++) {
+                int nextk = k1 + k2 + NUM_STAGES - 1;
+                int idx = (k2 + NUM_STAGES - 1) % NUM_STAGES;
+                bool pred = nextk < K / WARP_K;
+                load_act(act, nextk, K, A[idx], pred);
+                load_wgt(wgt, nextk, K, W[idx], pred);
+                load_amscale(ascales, nextk, amscale[idx], pred);
+                load_wmscale(wscales, nextk, wmscale[idx], pred);
+
+                // __syncthreads();
+                // if (alwaysfalse) {
+                //     dummy = clock();
+                // }
+
+                compute_fp4(A[k2], W[k2], amscale[k2], wmscale[k2], fpsum);
+
+                if (alwaysfalse) {
+                    dummy = clock();
+                }
+
+                // asm volatile ("membar.cta;");
+            }
+        }
+
+        unused_var(dummy, alwaysfalse);
+
+        if constexpr (USE_ALPHA) {
+#pragma unroll
+            for (auto &pack : fpsum) {
+#pragma unroll
+                for (int i = 0; i < 8; i++) {
+                    pack.data[i] *= alpha;
+                }
+            }
+        }
+
+        auto f16psum = packed_fp32_to_fp16(fpsum);
+
+        CHECK_NAN(f16psum, "f16psum");
+
+        Epilogue()(binfo, f16psum, M, N, K, epilogueArgs);
+    }
+
+    template<typename Epilogue, bool USE_ALPHA>
+    struct gemm_w4a4_fp4_kernel {
+        __device__
+        void operator()(
+            const packed_act_t *act,
+            const packed_wgt_t *wgt,
+            const packed_amscale_t *ascales,
+            const packed_wmscale_t *wscales,
+            float alpha,
+            int M, int N, int K,
+            Epilogue::Arguments epilogueArgs,
+            bool swapBlockXY,
+            bool alwaysfalse)
+        {
+            BlockInfo binfo = {
+                .bm = (int)blockIdx.x,
+                .bn = (int)blockIdx.y,
+                .numBlocksM = (int)gridDim.x,
+                .numBlocksN = (int)gridDim.y,
+            };
+            if (swapBlockXY) {
+                std::swap(binfo.bm, binfo.bn);
+                std::swap(binfo.numBlocksM, binfo.numBlocksN);
+            }
+            const int bm = binfo.bm;
+            const int bn = binfo.bn;
+
+            if constexpr (FP4_AVAILABLE) {
+                gemm_w4a4_fp4_block<Epilogue, USE_ALPHA>(
+                    binfo,
+                    act + bm * (K / WARP_K) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
+                    wgt + bn * (K / WARP_K) * WARP_N_TILES * WARP_SIZE,
+                    ascales + bm * (K / WARP_K) * NUM_WARPS * AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES,
+                    wscales + bn * (K / WARP_K) * WMSCALES_NUM_PACKS * WMSCALES_VALID_LANES,
+                    alpha,
+                    M, N, K,
+                    epilogueArgs,
+                    alwaysfalse
+                );
+            } else {
+                trap_no_fp4();
+            }
+        }
+    };
+
 public:
     template<bool ACT_UNSIGNED>
     __device__ __forceinline__

@@ -416,7 +780,7 @@ public:
     template<bool ACT_UNSIGNED, typename T>
     __device__ __forceinline__
     static void compute(act_warp A, wgt_warp W, ascale_warp ascale, wscale_warp wscale, T &fpsum) {
-        apply_scales([&](int i, int j) {
+        apply_scales<true>([&](int i, int j) {
             return mma<ACT_UNSIGNED>(A[i], W[j]);
         }, ascale, wscale, fpsum);
     }

@@ -530,6 +894,10 @@ public:
         const int laneId = threadIdx.x % WARP_SIZE;
         const int warpId = threadIdx.x / WARP_SIZE;

+#if 0
+        fpsum_warp fpsum;
+        GEMM_W4A4_Block<Config>()(act, wgt, ascales, wscales, K, fpsum, alwaysfalse);
+#else
         act_warp A[NUM_STAGES];          // 8
         wgt_warp W[NUM_STAGES];          // 32
         ascale_warp ascale[NUM_STAGES];  // 1

@@ -591,6 +959,8 @@ public:
         unused_var(dummy, alwaysfalse);
+#endif

 #if 0
         auto f16psum = packed_fp32_to_fp16(fpsum);
 #else

@@ -602,11 +972,13 @@ public:
         Epilogue()(binfo, f16psum, M, N, K, epilogueArgs);
     }

-    template<bool FUSE_GELU, bool USE_UNSIGNED>
+    template<bool FUSE_GELU, bool USE_UNSIGNED, bool USE_FP4>
     struct EpilogueQuantize {
+        using oscales_t = typename std::conditional_t<USE_FP4, packed_amscale_t, packed_ascale_t>;
         struct Arguments {
             packed_act_t *qout;
-            packed_ascale_t *oscales;
+            oscales_t *oscales;

             half_t shift_value;
             const packed_wscale_t *smooth_factor;

@@ -616,7 +988,7 @@ public:
         static constexpr int NUM_GROUPS = WARP_N_TILES / NUM_PACKS;

         __device__ __forceinline__
-        void apply_quantize(fpsum_warp fpsum, int M, int N, int K, packed_act_t *qout, packed_ascale_t *oscales, half_t shift_value, const packed_wscale_t *smooth_factor) {
+        void apply_quantize(fpsum_warp fpsum, int M, int N, int K, packed_act_t *qout, oscales_t *oscales, half_t shift_value, const packed_wscale_t *smooth_factor) {
             const int laneId = threadIdx.x % WARP_SIZE;
             const int warpId = threadIdx.x / WARP_SIZE;

@@ -627,6 +999,8 @@ public:
 #pragma unroll
             for (int group = 0; group < NUM_GROUPS; group++) {
+                amscale_warp omscale;
+
 #pragma unroll
                 for (int i = 0; i < WARP_M_TILES; i++) {
                     packed_fpsum_t tmp[NUM_PACKS];

@@ -652,15 +1026,6 @@ public:
                         // dst = src;
                     }

-                    // auto h2div = [](half2_t a, half2_t b) ALWAYSINLINE {
-                    //     float2 af = half22float2(a);
-                    //     float2 bf = half22float2(b);
-                    //     float2 of;
-                    //     of.x = __fdividef(af.x, bf.x);
-                    //     of.y = __fdividef(af.y, bf.y);
-                    //     return float22half2<half2_t>(of);
-                    // };
-
                         tmp[j].data[0] = h2div(tmp[j].data[0], ws1);
                         tmp[j].data[1] = h2div(tmp[j].data[1], ws1);
                         tmp[j].data[2] = h2div(tmp[j].data[2], ws2);

@@ -668,13 +1033,26 @@ public:
                     }

                     packed_act_t qresult;
-                    quantize_w4a4_from_fpsum_warp<USE_UNSIGNED>(tmp, qresult, &oscale_shmem[warpId][i * INSN_M]);
+                    if constexpr (USE_FP4) {
+                        quantize_w4a4_fp4_from_fpsum_warp(tmp, qresult, omscale[i / 2 / AMSCALES_PACK_SIZE].data[i / 2 % AMSCALES_PACK_SIZE], i % 2);
+                    } else {
+                        quantize_w4a4_from_fpsum_warp<USE_UNSIGNED>(tmp, qresult, &oscale_shmem[warpId][i * INSN_M]);
+                    }
                     store(&qout[((group * NUM_WARPS + warpId) * WARP_M_TILES + i) * WARP_SIZE + laneId], qresult);
                 }

-                __syncwarp();
-                pack_ascales(&oscale_shmem[warpId][0], &oscales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
-                __syncwarp();
+                if constexpr (USE_FP4) {
+#pragma unroll
+                    for (int k = 0; k < AMSCALES_NUM_PACKS; k++) {
+                        store(&oscales[((group * NUM_WARPS + warpId) * AMSCALES_NUM_PACKS + k) * AMSCALES_VALID_LANES + laneId], omscale[k]);
+                    }
+                }
+                if constexpr (!USE_FP4) {
+                    __syncwarp();
+                    pack_ascales(&oscale_shmem[warpId][0], &oscales[(group * NUM_WARPS + warpId) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES]);
+                    __syncwarp();
+                }
             }
         }

@@ -683,13 +1061,18 @@ public:
             const int bm = binfo.bm;
             const int bn = binfo.bn;

-            apply_quantize(
-                fpsum, M, N, K,
-                args.qout + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
-                args.oscales + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * ASCALES_NUM_PACKS * ASCALES_VALID_LANES,
-                args.shift_value,
-                args.smooth_factor + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
-            );
+            if constexpr (!USE_FP4 || FP4_AVAILABLE) {
+                apply_quantize(
+                    fpsum, M, N, K,
+                    args.qout + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * WARP_M_TILES * WARP_SIZE,
+                    args.oscales + (bm * N / WARP_K + bn * NUM_GROUPS) * NUM_WARPS * (USE_FP4 ? AMSCALES_NUM_PACKS * AMSCALES_VALID_LANES : ASCALES_NUM_PACKS * ASCALES_VALID_LANES),
+                    args.shift_value,
+                    args.smooth_factor + bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES
+                );
+            } else {
+                trap_no_fp4();
+            }
         }
     };

     // using EpilogueQuantizeFuseGelu = EpilogueQuantize<true>;

@@ -937,8 +1320,10 @@ public:
         }
     };

-    template<bool fuse_glu>
+    template<bool fuse_glu, bool use_fp4>
     struct quantize_w4a4_fuse_lora_kernel {
+        using oscales_t = typename std::conditional_t<use_fp4, packed_amscale_t, packed_ascale_t>;
+
         static constexpr size_t SHMEM_PER_WARP = ceilDiv<size_t>(load_act_to_fpsum<fuse_glu>::SHMEM_SIZE, 128) * 128;
         static constexpr size_t SHMEM_SIZE = SHMEM_PER_WARP * NUM_WARPS;

@@ -946,7 +1331,7 @@ public:
             const half_t *input;
             const packed_wscale_t *smooth_factor;
             packed_act_t *output;
-            packed_ascale_t *oscales;
+            oscales_t *oscales;
             const packed_fpsum_t *lora_wgt_down;
             float *lora_act;

@@ -999,7 +1384,7 @@ public:
                 .lora_act = args.lora_act,
             });

-            EpilogueQuantize<false, false>()(binfo, fpsum, args.M, args.N, 0, typename EpilogueQuantize<false, false>::Arguments{
+            EpilogueQuantize<false, false, use_fp4>()(binfo, fpsum, args.M, args.N, 0, typename EpilogueQuantize<false, false, use_fp4>::Arguments{
                 .qout = args.output,
                 .oscales = args.oscales,
                 .shift_value = 0,

@@ -1488,7 +1873,6 @@ public:
             );
         }
     };
-    };
 };

 };  // namespace nunchaku::kernels
\ No newline at end of file
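The new quantize_w4a4_fp4_from_fpsum_warp produces block-scaled FP4 data: each micro-group of values is quantized to E2M1 (largest magnitude 6.0) with a per-group FP8 E4M3 scale (clamped at 448.0), which is the operand format the mma.sync ... kind::mxf4nvf4 ... scale_vec::4X instruction above consumes. A simplified host-side reference of that per-group scheme — my own sketch, assuming the 16-wide micro-scaling group implied by the k64 scale_vec::4X MMA; the rounding and packing details differ from the PTX cvt path:

// Simplified reference (not from the commit) of per-group FP4 quantization.
#include <algorithm>
#include <array>
#include <cmath>

struct Fp4Group {
    float scale;                  // stored as FP8 E4M3 on the GPU (max 448)
    std::array<float, 16> quant;  // decoded E2M1 values; packed as 4-bit codes on the GPU
};

Fp4Group quantize_group_fp4(const std::array<float, 16> &x) {
    // Representable E2M1 magnitudes.
    static const float grid[8] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
    float maxabs = 0.f;
    for (float v : x) maxabs = std::max(maxabs, std::fabs(v));

    Fp4Group g;
    g.scale = std::min(maxabs / 6.0f, 448.0f);       // mirrors QVALUE_MAX / MSCALE_MAX above
    float rscale = g.scale > 0.f ? 1.0f / g.scale : 0.f;
    for (int i = 0; i < 16; i++) {
        float t = std::fabs(x[i]) * rscale;
        float best = grid[0];
        for (float q : grid)
            if (std::fabs(t - q) < std::fabs(t - best)) best = q;
        g.quant[i] = std::copysign(best, x[i]);
    }
    return g;
}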
src/kernels/zgemm/gemm_w4a4_launch.cuh
View file @
b1b44398
...
@@ -12,6 +12,8 @@ class GEMM_W4A4_Launch {
...
@@ -12,6 +12,8 @@ class GEMM_W4A4_Launch {
    using packed_wgt_t = typename GEMM::packed_wgt_t;
    using packed_ascale_t = typename GEMM::packed_ascale_t;
    using packed_wscale_t = typename GEMM::packed_wscale_t;
    using packed_amscale_t = typename GEMM::packed_amscale_t;
    using packed_wmscale_t = typename GEMM::packed_wmscale_t;
    using packed_fpsum_t = typename GEMM::packed_fpsum_t;
    using half_t = typename GEMM::half_t;
...
@@ -38,9 +40,12 @@ public:
        Tensor out_linearattn,          // linear [B, (M), N / 3]
        bool act_unsigned,
        std::vector<float> lora_scales, // [R / 16]
        bool fuse_silu,
        bool fp4,
        float alpha,
        Tensor wcscales                 // packed ws [N]
    );
    static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4);
    static void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
    static void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
...
src/kernels/zgemm/gemm_w4a4_launch_impl.cuh
View file @
b1b44398
...
@@ -30,7 +30,10 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
    Tensor out_linearattn,          // linear [B, (M), N / 3]
    bool act_unsigned,
    std::vector<float> lora_scales, // [R / 16]
    bool fuse_silu,
    bool fp4,
    float alpha,
    Tensor wcscales                 // packed ws [N]
) {
    int M = act.numel() / act.shape[-1];
    int N = wgt.shape[0];
...
@@ -68,58 +71,111 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
        std::swap(grid.x, grid.y);
    }

    dispatchBool(fp4, [&]<bool USE_FP4>() {
        // test_sizeof<typename Epilogue::Arguments>();
        // std::apply([](auto ...args) {
        //     (test_sizeof<decltype(args)>(), ...);
        // }, args);
        // constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;

        if constexpr (!USE_FP4) {
            dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
                auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
                    const packed_act_t *,
                    const packed_wgt_t *,
                    const packed_ascale_t *,
                    const packed_wscale_t *,
                    int, int, int,
                    typename Epilogue::Arguments,
                    bool, bool>;
                if (shmem >= 24 * 1024) {
                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
                }
                assert(alpha == 1.0f);
                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
                    act.data_ptr<packed_act_t>(),
                    wgt.data_ptr<packed_wgt_t>(),
                    ascales.data_ptr<packed_ascale_t>(),
                    wscales.data_ptr<packed_wscale_t>(),
                    M, N, K,
                    args,
                    swapBlockMN,
                    false);
                checkCUDA(cudaGetLastError());
            });
            return;
        }

        if constexpr (USE_FP4) {
            dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
                assert(!act_unsigned);
                auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
                    const packed_act_t *,
                    const packed_wgt_t *,
                    const packed_amscale_t *,
                    const packed_wmscale_t *,
                    float, int, int, int,
                    typename Epilogue::Arguments,
                    bool, bool>;
                if (shmem >= 24 * 1024) {
                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
                }
                assert(ascales.dtype() == Tensor::FP8_E4M3);
                assert(wscales.dtype() == Tensor::FP8_E4M3);
                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
                    act.data_ptr<packed_act_t>(),
                    wgt.data_ptr<packed_wgt_t>(),
                    ascales.data_ptr<packed_amscale_t>(),
                    wscales.data_ptr<packed_wmscale_t>(),
                    alpha,
                    M, N, K,
                    args,
                    swapBlockMN,
                    false);
                checkCUDA(cudaGetLastError());
            });
            return;
        }

        // if constexpr (USE_FP4 && !FP4_AVAILABLE) {
        //     throw std::runtime_error("FP4 kernel is not available");
        // }
    });
};
auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
    assert(!bias.valid() || bias.numel() == N);
    assert(!wcscales.valid() || wcscales.numel() == N);

    dispatchBool(bias.valid(), [&]<bool USE_BIAS>() {
        dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() {
            using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>;
            // append EpilogueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
            // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
            using Epilogue = typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
            return launch.template operator()<Epilogue>({
                typename EpilogueBias::Arguments{
                    .bias  = USE_BIAS  ? bias.data_ptr<packed_wscale_t>()     : nullptr,
                    .scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr,
                },
                nextArgs,
                {}
            });
        });
    });
};
// auto launch_bias = launch;
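Both launch paths above lean on dispatchBool (and dispatchVal for the LoRA ranks) to turn runtime flags such as fp4, act_unsigned, alpha != 1.0f, bias.valid(), and wcscales.valid() into compile-time template parameters, so every flag combination instantiates a fully specialized kernel. A minimal sketch of such a dispatcher, assuming a C++20 templated lambda as the callee; the real helper is defined elsewhere in the repository:

#include <utility>

template<typename F>
void dispatchBool(bool value, F &&func) {
    // Instantiate the callee for both compile-time values and select one at runtime.
    if (value) {
        std::forward<F>(func).template operator()<true>();
    } else {
        std::forward<F>(func).template operator()<false>();
    }
}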
...
@@ -206,29 +262,32 @@ void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
static constexpr float SHIFT_GELU = 0.171875f;

dispatchBool(fp4, [&]<bool USE_FP4>() {
    constexpr bool USE_UNSIGNED = !USE_FP4;
    using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
    auto argsQuantize = typename EpilogueQuantize::Arguments{
        .qout = qout.data_ptr<packed_act_t>(),
        .oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
        .shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
        .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
    };

    // TODO: check if gelu is needed
    if (out.valid()) {
        launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
            typename GEMM::EpilogueDefault::Arguments{
                .out = out.data_ptr<half_t>(),
                .actualM = actualM,
                .actualN = actualN,
            },
            argsQuantize
        }, {});
    } else {
        launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
    }
});
} else if (out_linearattn.valid()) {
    assert(out_vk.valid());
...
@@ -326,7 +385,7 @@ void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
}

template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
    const int actualM = input.numel() / input.shape[-1];
    const int actualN = input.shape[-1];
...
@@ -338,8 +397,13 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
    assert(output.shape[-1] == N / 2);
    // assert(oscales.dtype() == Tensor::FP16);
    if (fp4) {
        assert(oscales.dtype() == Tensor::FP8_E4M3);
        assert(oscales.numel() == M * N / GEMM::WARP_K * 4);
    } else {
        assert(isTypeMatch<half_t>(oscales.dtype()));
        assert(oscales.numel() == M * N / GEMM::WARP_K);
    }

    const int rank = lora_down.shape[1];
...
@@ -354,30 +418,32 @@ void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor
    dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
        dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
            dispatchBool(fp4, [&]<bool USE_FP4>() {
                using Lora = typename GEMM::Lora<RANK>;
                using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;

                auto func = invoke_kernel<kernel, typename kernel::Arguments>;
                checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));

                // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));

                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(typename kernel::Arguments{
                    .input = input.data_ptr<half_t>(),
                    .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
                    .output = output.data_ptr<packed_act_t>(),
                    .oscales = oscales.data_ptr<typename kernel::oscales_t>(),
                    .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                    .lora_act = lora_act_out.data_ptr<float>(),
                    .M = M,
                    .N = N,
                    .actualM = actualM,
                    .actualN = actualN,
                });
                checkCUDA(cudaGetLastError());
            });
        });
    });
}
...
src/kernels/zgemm/gemm_w8a8.cu
View file @
b1b44398
...
@@ -100,9 +100,9 @@ void gemm_w8a8(Tensor act, // [M, K]
// append EpilogueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
// ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias<true, false>, NextEpilogue, GEMM::EpilogueNop>;
return launch.template operator()<Epilogue>({
    GEMM::EpilogueBias<true, false>::Arguments{
        .bias = bias.data_ptr<GEMM::packed_wscale_t>(),
    },
    nextArgs,
...
src/kernels/zgemm/zgemm.h
View file @
b1b44398
...
@@ -27,11 +27,14 @@ void gemm_w4a4(
    Tensor out_linearattn,          // linear [B, (M), N / 3]
    bool act_unsigned,
    std::vector<float> lora_scales, // [R / 16]
    bool fuse_silu,
    bool fp4,
    float alpha,
    Tensor wcscales
);
void linearattn_vk_mul_q(Tensor q, Tensor vk);
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {}, bool fuse_glu = false, bool fp4 = false);
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
...
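With the defaulted fuse_glu and fp4 arguments, existing callers of quantize_w4a4_act_fuse_lora keep the INT4 behavior unchanged, while FP4 users opt in explicitly. A hypothetical call under the new declaration; tensor names, shapes, and the surrounding setup are illustrative only and not part of this commit:

// FP4 activation quantization fused with the LoRA down-projection.
// oscales must be FP8_E4M3 and sized per the fp4 branch of the asserts in the launch code above.
quantize_w4a4_act_fuse_lora(
    input,         // half/bf16 activations, shape [..., N]
    output,        // packed 4-bit activations, last dim N / 2
    oscales,       // per-group micro-scales
    lora_down,     // packed LoRA down-projection weights
    lora_act_out,  // float LoRA activations
    /*smooth=*/{},
    /*fuse_glu=*/false,
    /*fp4=*/true);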