fengzch-das / nunchaku · Commit a9f1b7af

[major] move bf16 to compiler args; fp16 experiment

Authored Nov 10, 2024 by sxtyzhangzk; committed by Zhekai Zhang, Nov 10, 2024
Parent: d02f26df

Showing 9 changed files with 92 additions and 5 deletions (+92 -5)
nunchaku/csrc/flux.h                 +4   -0
nunchaku/csrc/pybind.cpp             +1   -0
src/FluxModel.cpp                    +31  -1
src/FluxModel.h                      +4   -1
src/kernels/awq/gemv_awq.cu          +0   -1
src/kernels/misc_kernels.cu          +22  -0
src/kernels/misc_kernels.h           +2   -0
src/kernels/misc_kernels_impl.cuh    +24  -2
src/kernels/utils.cuh                +4   -0
nunchaku/csrc/flux.h

@@ -185,6 +185,10 @@ public:
         });
     }
 
+    void forceFP16Attention(bool enable) {
+        Attention::setForceFP16(net.get(), enable);
+    }
+
 private:
     void checkModel() {
nunchaku/csrc/pybind.cpp

@@ -26,6 +26,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def("stopDebug", &QuantizedFluxModel::stopDebug)
         .def("getDebugResults", &QuantizedFluxModel::getDebugResults)
         .def("setLoraScale", &QuantizedFluxModel::setLoraScale)
+        .def("forceFP16Attention", &QuantizedFluxModel::forceFP16Attention)
         ;
     py::class_<QuantizedGEMM>(m, "QuantizedGEMM")
         // .def(torch::init<>())
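A minimal, self-contained pybind11 sketch of the binding pattern used above (hypothetical demo_module / DemoModel names, not the project's actual module): a .def() on a member function taking a bool surfaces in Python as obj.forceFP16Attention(True) / obj.forceFP16Attention(False).

// demo_binding.cpp -- illustrative only; build as a normal pybind11 extension
#include <pybind11/pybind11.h>

namespace py = pybind11;

struct DemoModel {
    bool force_fp16 = false;
    void forceFP16Attention(bool enable) { force_fp16 = enable; }
};

PYBIND11_MODULE(demo_module, m) {
    py::class_<DemoModel>(m, "DemoModel")
        .def(py::init<>())
        .def("forceFP16Attention", &DemoModel::forceFP16Attention);
}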
src/FluxModel.cpp

@@ -106,7 +106,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
 Attention::Attention(int num_heads, int dim_head, Device device) :
-    num_heads(num_heads), dim_head(dim_head)
+    num_heads(num_heads), dim_head(dim_head), force_fp16(false)
 {
     headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
     for (int i = 0; i < num_heads; i++) {

@@ -116,6 +116,8 @@ Attention::Attention(int num_heads, int dim_head, Device device) :
 }
 
 Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
+    const bool cast_fp16 = this->force_fp16 && qkv.scalar_type() != Tensor::FP16;
+
     assert(qkv.ndims() == 3);
 
     const Device device = qkv.device();

@@ -169,6 +171,14 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
         }
     }
 
+    if (cast_fp16) {
+        Tensor tmp = Tensor::empty(qkv.shape.dataExtent, Tensor::FP16, qkv.device());
+        cast(qkv, tmp);
+        qkv = tmp;
+    }
+
+    debug("qkv", qkv);
+
     Tensor cu_seqlens = cu_seqlens_cpu.copy(device);
     Tensor reshaped = qkv.view({batch_size * num_tokens, num_heads * 3, dim_head});

@@ -192,6 +202,14 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
         false, false, false, -1, -1
     ).front();
 
+    debug("raw_attn_output", raw_attn_output);
+
+    if (cast_fp16) {
+        Tensor tmp = Tensor::empty(raw_attn_output.shape.dataExtent, Tensor::BF16, raw_attn_output.device());
+        cast(raw_attn_output, tmp);
+        raw_attn_output = tmp;
+    }
+
     /**
     Tensor raw_attn_output = mha_varlen_fwd(q, k, v,
         cu_seqlens,

@@ -229,6 +247,16 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
     return raw_attn_output;
 }
 
+void Attention::setForceFP16(Module *module, bool value) {
+    spdlog::info("{} force fp16 attention", value ? "Enable" : "Disable");
+
+    module->traverse([&](Module *m) {
+        if (Attention *attn = dynamic_cast<Attention *>(m)) {
+            attn->force_fp16 = value;
+        }
+    });
+}
+
 FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, Tensor::ScalarType dtype, Device device) :
     dim(dim),
     dim_head(attention_head_dim / num_attention_heads),

@@ -250,6 +278,7 @@ FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attentio
         (qkv_proj, "qkv_proj")
         (norm_q, "norm_q")
         (norm_k, "norm_k")
+        (attn, "attn")
         (out_proj, "out_proj")
         ;
 }

@@ -328,6 +357,7 @@ JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, i
         (norm_k, "norm_k")
         (norm_added_q, "norm_added_q")
         (norm_added_k, "norm_added_k")
+        (attn, "attn")
         (out_proj, "out_proj")
         (out_proj_context, "out_proj_context")
         (norm2, "norm2")
src/FluxModel.h

@@ -53,16 +53,19 @@ private:
     LayerNorm norm;
 };
 
-class Attention {
+class Attention : public Module {
 public:
     static constexpr int POOL_SIZE = 128;
 
     Attention(int num_heads, int dim_head, Device device);
     Tensor forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio);
 
+    static void setForceFP16(Module *module, bool value);
+
 public:
     const int num_heads;
     const int dim_head;
+    bool force_fp16;
 
 private:
     Tensor cu_seqlens_cpu;
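Making Attention derive from Module is what lets the static setForceFP16 above flip every attention layer with a single traversal of the module tree. A standalone sketch of that visitor pattern (hypothetical Module / DemoAttention types with a hand-rolled traverse, not the project's actual Module class):

#include <functional>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the project's Module base class.
struct Module {
    virtual ~Module() = default;
    std::vector<Module *> children;
    void traverse(const std::function<void(Module *)> &fn) {
        fn(this);
        for (Module *child : children) child->traverse(fn);
    }
};

struct DemoAttention : Module {
    bool force_fp16 = false;
};

// Same idea as Attention::setForceFP16: visit every node and flip the flag
// only on the nodes that are actually attention layers.
void setForceFP16(Module *root, bool value) {
    root->traverse([&](Module *m) {
        if (auto *attn = dynamic_cast<DemoAttention *>(m)) {
            attn->force_fp16 = value;
        }
    });
}

int main() {
    DemoAttention a1, a2;
    Module block;
    block.children = {&a1, &a2};
    setForceFP16(&block, true);
    std::cout << a1.force_fp16 << " " << a2.force_fp16 << "\n";  // prints: 1 1
}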
src/kernels/awq/gemv_awq.cu

@@ -27,7 +27,6 @@
 #include "gemv_awq.h"
 #include "../dispatch_utils.h"
-#define ENABLE_BF16 1
 #include "../utils.cuh"
 #include <cuda_fp16.h>
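With the hard-coded define removed, the ENABLE_BF16 macro is expected to come from the build system instead, which is what the commit title's "move bf16 to compiler args" refers to. The actual build files are not part of this diff, so as an illustration only, a compiler invocation would pass the define explicitly:

nvcc -DENABLE_BF16 -c src/kernels/awq/gemv_awq.cu -o gemv_awq.o

Omitting the flag builds the kernels without the cuda_bf16.h path that utils.cuh now guards (see the last hunk below).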
src/kernels/misc_kernels.cu

@@ -188,6 +188,28 @@ Tensor quant_static_fuse_gelu(Tensor x, float scale) {
     return out;
 }
 
+void cast(Tensor input, Tensor output) {
+    assert(input.is_contiguous());
+    assert(output.is_contiguous());
+    assert(input.shape.dataExtent == output.shape.dataExtent);
+
+    auto stream = getCurrentCUDAStream();
+
+    dispatch(input.scalar_type(), [&]<typename input_t>() {
+        dispatch(output.scalar_type(), [&]<typename output_t>() {
+            constexpr int unroll = 16 / std::max(sizeof(input_t), sizeof(output_t));
+            int threadsPerBlock = 1024;
+            int blocksPerGrid = (int)ceilDiv<int64_t>(input.numel(), threadsPerBlock * unroll);
+            cast_kernel<input_t, output_t, unroll><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
+                input.data_ptr<input_t>(), output.data_ptr<output_t>(), input.numel());
+            checkCUDA(cudaGetLastError());
+        });
+    });
+}
+
 Tensor topk(Tensor x, int k) {
     constexpr int MAXK = 64 + 4;
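Worked example of the launch configuration above: for a BF16-to-FP16 cast both element types are 2 bytes, so unroll = 16 / max(2, 2) = 8 elements (16 bytes of the wider type) per thread. With 1024 threads per block, each block covers 8192 elements, and a tensor of 1,048,576 elements launches ceilDiv(1048576, 8192) = 128 blocks.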
src/kernels/misc_kernels.h

@@ -11,6 +11,8 @@ void splitqkv(Tensor qkv, Tensor q, Tensor k, Tensor v);
 Tensor quant_static(Tensor x, float scale);
 Tensor quant_static_fuse_gelu(Tensor x, float scale);
 
+void cast(Tensor input, Tensor output);
+
 Tensor topk(Tensor x, int k);
 
 template<size_t N>
src/kernels/misc_kernels_impl.cuh

 #include "reduction_utils.cuh"
 #include <array>
-#include <cuda_fp16.h>
-#include <cuda_bf16.h>
 #include "utils.cuh"
 #include "activation_kernels_impl.cuh"
+
+#include <cuda_fp16.h>
 
 template<typename T>
 __global__ void add_kernel(T *a, T *b, T *c, size_t length) {

@@ -162,7 +164,27 @@ __global__ void quant_kernel_static_fuse_gelu(const T * input, int8_t * output,
     *reinterpret_cast<I8vec *>(&output[i]) = routput;
 }
 
+#include <cstdio>
+
+template<typename Tin, typename Tout, int unroll>
+__global__ void cast_kernel(const Tin *input, Tout *output, size_t length) {
+    const int i = (blockIdx.x * blockDim.x + threadIdx.x) * unroll;
+
+    using Tvec_in = ::Tvec<Tin, unroll>;
+    using Tvec_out = ::Tvec<Tout, unroll>;
+
+    Tvec_in rinput = *reinterpret_cast<const Tvec_in *>(&input[i]);
+    Tvec_out routput;
+#pragma unroll
+    for (int k = 0; k < unroll; k++) {
+        routput.data[k] = cuda_cast<Tout, Tin>(rinput.data[k]);
+        if constexpr (std::is_same_v<Tout, half>) {
+            routput.data[k] = __hmin(routput.data[k], (half)65504);
+            routput.data[k] = __hmax(routput.data[k], (half)-65504);
+        }
+    }
+    *reinterpret_cast<Tvec_out *>(&output[i]) = routput;
+}
+
 // input: [..., N]
 // output: [..., K] of index in reverse order
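The __hmin/__hmax clamp in cast_kernel matters because 65504 is the largest finite FP16 value: BF16 or FP32 activations above that would otherwise round to +inf when cast down. A standalone CUDA sketch of the effect (illustrative only; assumes an sm_80+ GPU, matching the half intrinsics used above):

// clamp_demo.cu -- illustrative only
#include <cstdio>
#include <cuda_fp16.h>

__global__ void clamp_demo() {
    const float big = 70000.0f;                                   // representable in FP32/BF16, not in FP16
    const half raw = __float2half(big);                           // rounds to +inf
    const half clamped = __hmin(__float2half(big), (half)65504);  // stays at 65504
    printf("raw=%f clamped=%f\n", __half2float(raw), __half2float(clamped));
}

int main() {
    clamp_demo<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}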
src/kernels/utils.cuh

@@ -8,6 +8,10 @@
 #include <cuda_fp16.h>
 
+#ifdef ENABLE_BF16
+#include <cuda_bf16.h>
+#endif
+
 template<typename T>
 struct num_elems;
 template<> struct num_elems<float> { static constexpr int value = 1; };
 template<> struct num_elems<float2> { static constexpr int value = 2; };