fengzch-das / nunchaku · Commits

Commit e9ad0535, authored Jan 23, 2025 by muyangli

[major] support SANA

parent 9eb2cee0

Changes: 86. Showing 20 changed files with 3317 additions and 145 deletions (+3317, -145).
src/SanaModel.cpp                        +332   -0
src/SanaModel.h                          +98    -0
src/Tensor.h                             +11    -4
src/interop/torch.cpp                    +1     -1
src/interop/torch.h                      +4     -0
src/kernels/dispatch_cutlass.h           +19    -0
src/kernels/dwconv.cu                    +140   -0
src/kernels/dwconv.h                     +3     -1
src/kernels/gemm_batched.cu              +1     -1
src/kernels/gemm_f16.cu                  +110   -99
src/kernels/gemm_f16.h                   +2     -2
src/kernels/gemm_w8a8.cu                 +1     -1
src/kernels/misc_kernels.cu              +62    -6
src/kernels/misc_kernels.h               +7     -1
src/kernels/misc_kernels_impl.cuh        +25    -16
src/kernels/zgemm/gemm_base.cuh          +828   -0
src/kernels/zgemm/gemm_utils.cuh         +30    -13
src/kernels/zgemm/gemm_w4a4.cu           +98    -0
src/kernels/zgemm/gemm_w4a4.cuh          +1494  -0
src/kernels/zgemm/gemm_w4a4_launch.cuh   +51    -0
src/SanaModel.cpp (new file, 0 → 100644)

#include "SanaModel.h"
#include "kernels/zgemm/zgemm.h"
#include "flash_api.h"
#include "kernels/misc_kernels.h"

#include <nvtx3/nvToolsExt.h>

using spdlog::fmt_lib::format;
using namespace nunchaku;

SanaLinearAttention::SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device) :
    dim(dim), dim_pad(ceilDiv(dim, 128) * 128),
    qkv_proj(dim, dim_pad * 3, bias, dtype, device),
    out_proj(dim_pad, dim, bias, dtype, device),
    pag_to_v(std::nullopt)
{
    registerChildren
        (qkv_proj, "qkv_proj")
        (out_proj, "out_proj")
    ;

    if (pag) {
        pag_to_v.emplace(dim, dim_pad, bias, dtype, device);
        registerChildren(pag_to_v.value(), "pag_to_v");
    }
}

Tensor SanaLinearAttention::forward(Tensor x, Tensor out) {
    constexpr int HEAD_DIM = 32;

    assert(x.ndims() == 3);

    const int batch_size = x.shape[0];
    const int num_tokens = x.shape[1];
    const int num_tokens_pad = ceilDiv(num_tokens, 256) * 256;

    assert(x.shape[2] == dim);

    const int num_heads = dim_pad / HEAD_DIM;

    if (num_tokens_pad != num_tokens) {
        spdlog::debug("SanaLinearAttention: pad num_tokens from {} to {}", num_tokens, num_tokens_pad);
        Tensor x_pad = Tensor::allocate({batch_size, num_tokens_pad, dim}, x.dtype(), x.device());
        x_pad.zero_();
        for (int i = 0; i < batch_size; i++) {
            x_pad.slice(0, i, i + 1).slice(1, 0, num_tokens).copy_(x.slice(0, i, i + 1));
        }
        x = x_pad;
    }

    auto qact = qkv_proj.quantize(x, false);

    Tensor q  = Tensor::allocate({batch_size, num_tokens_pad, dim_pad}, x.dtype(), x.device());
    Tensor vk = Tensor::allocate({batch_size, num_heads, HEAD_DIM + 1, HEAD_DIM}, Tensor::FP32, x.device());

    kernels::gemm_w4a4(
        qact.act, qkv_proj.qweight, {}, {},
        qact.ascales, qkv_proj.wscales, {}, {},
        qact.lora_act, qkv_proj.lora_up, {}, {},
        {}, {}, {},
        qkv_proj.bias, {},
        vk, q,
        qact.is_unsigned, qkv_proj.lora_scales, false);

    debug("vk", vk);
    debug("q", q);

    kernels::linearattn_vk_mul_q(q, vk);
    debug("raw_attn_output", q);

    if (num_tokens_pad != num_tokens) {
        Tensor q_unpad = Tensor::allocate({batch_size, num_tokens, dim_pad}, q.dtype(), q.device());
        for (int i = 0; i < batch_size; i++) {
            q_unpad.slice(0, i, i + 1).copy_(q.slice(0, i, i + 1).slice(1, 0, num_tokens));
        }
        q = q_unpad;
    }

    // kernels::gemm_w8a8_fuse_litela(qact.act, qkv.qweight, q, vk, qact.ascales, qkv.wscales);
    // return out_proj.forward(q);

    if (!out.valid()) {
        out = Tensor::allocate({batch_size, num_tokens, dim}, q.dtype(), q.device());
    }
    out_proj.forward(q, out);
    return out;
}
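For reference, a minimal CPU sketch of the per-head linear-attention contraction that the fused gemm_w4a4 epilogue (out_vk) and linearattn_vk_mul_q implement on the GPU. The role of the extra (HEAD_DIM + 1)-th row of vk as a key-sum normalizer is inferred from the tensor shape above, not stated in this diff, and the epsilon is my own addition; this is an illustration, not the commit's kernel.

#include <vector>

// q, k, v: [T][D] for one head; vk: [(D + 1)][D]; out: [T][D].
std::vector<std::vector<float>> linear_attention_ref(
        const std::vector<std::vector<float>> &q,
        const std::vector<std::vector<float>> &k,
        const std::vector<std::vector<float>> &v) {
    const int T = (int)q.size(), D = (int)q[0].size();
    std::vector<std::vector<float>> vk(D + 1, std::vector<float>(D, 0.0f));
    for (int i = 0; i < T; i++) {
        for (int d = 0; d < D; d++) {
            for (int e = 0; e < D; e++) vk[d][e] += v[i][d] * k[i][e];  // accumulate v k^T
            vk[D][d] += k[i][d];                                        // key sum (assumed normalizer row)
        }
    }
    std::vector<std::vector<float>> out(T, std::vector<float>(D, 0.0f));
    for (int j = 0; j < T; j++) {
        float norm = 1e-6f;                                             // small epsilon (my assumption)
        for (int e = 0; e < D; e++) norm += vk[D][e] * q[j][e];
        for (int d = 0; d < D; d++) {
            float acc = 0.0f;
            for (int e = 0; e < D; e++) acc += vk[d][e] * q[j][e];
            out[j][d] = acc / norm;
        }
    }
    return out;
}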
Tensor SanaLinearAttention::forward_pag(Tensor x, bool cfg) {
    const int batch_size = x.shape[0];
    const int num_tokens = x.shape[1];

    Tensor out = Tensor::allocate({batch_size, num_tokens, dim}, x.dtype(), x.device());

    Tensor x_org, x_ptb;
    Tensor out_org, out_ptb;

    if (cfg) {
        assert(batch_size % 3 == 0);
        x_org   = x.slice(0, 0, batch_size * 2 / 3);
        x_ptb   = x.slice(0, batch_size * 2 / 3, batch_size);
        out_org = out.slice(0, 0, batch_size * 2 / 3);
        out_ptb = out.slice(0, batch_size * 2 / 3, batch_size);
    } else {
        assert(batch_size % 2 == 0);
        x_org   = x.slice(0, 0, batch_size / 2);
        x_ptb   = x.slice(0, batch_size / 2, batch_size);
        out_org = out.slice(0, 0, batch_size / 2);
        out_ptb = out.slice(0, batch_size / 2, batch_size);
    }

    this->forward(x_org, out_org);

    Tensor v_ptb = this->pag_to_v.value().forward(x_ptb);
    this->out_proj.forward(v_ptb, out_ptb);

    return out;
}
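A small sketch of the batch split forward_pag performs: with cfg the first two thirds of the batch run full linear attention and the last third takes the pag_to_v / out_proj shortcut, otherwise the batch halves. Which conditioning signals occupy each sub-batch is an assumption of the caller, not spelled out here; the values below are only illustrative.

#include <cstdio>

int main() {
    for (int batch_size : {6, 4}) {
        bool cfg = (batch_size % 3 == 0);  // illustrative choice mirroring the asserts above
        int split = cfg ? batch_size * 2 / 3 : batch_size / 2;
        std::printf("batch=%d cfg=%d -> org=[0,%d) ptb=[%d,%d)\n",
                    batch_size, (int)cfg, split, split, batch_size);
    }
    return 0;
}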
MultiHeadCrossAttention::MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device) :
    num_heads(num_heads), head_dim(head_dim),
    q_linear(num_heads * head_dim, num_heads * head_dim, true, dtype, device),
    kv_linear(num_heads * head_dim, num_heads * head_dim * 2, true, dtype, device),
    out_proj(num_heads * head_dim, num_heads * head_dim, true, dtype, device)
{
    registerChildren
        (q_linear, "q_linear")
        (kv_linear, "kv_linear")
        (out_proj, "out_proj")
    ;
}

Tensor MultiHeadCrossAttention::forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt) {
    assert(x.ndims() == 3);
    assert(cond.ndims() == 2);
    assert(cu_seqlens_img.ndims() == 1);
    assert(cu_seqlens_txt.ndims() == 1);

    const int batch_size     = x.shape[0];
    const int num_tokens_img = x.shape[1];
    const int num_tokens_txt = cond.shape[0];

    assert(cu_seqlens_img.shape[0] == batch_size + 1);
    assert(cu_seqlens_txt.shape[0] == batch_size + 1);

    Tensor q  = q_linear.forward(x).view({batch_size * num_tokens_img, num_heads, head_dim});
    Tensor kv = kv_linear.forward(cond).view({num_tokens_txt, num_heads * 2, head_dim});
    Tensor k  = kv.slice(1, 0, num_heads);
    Tensor v  = kv.slice(1, num_heads, num_heads * 2);

    Tensor attn_output = mha_varlen_fwd(q, k, v,
        cu_seqlens_img, cu_seqlens_txt,
        num_tokens_img, num_tokens_txt,
        0.0f,
        pow(q.shape[-1], (-0.5)),
        false, false, -1, -1, false
    ).front().view({batch_size, num_tokens_img, num_heads * head_dim});

    // Tensor attn_output = mha_fwd(q, k, v,
    //     0.0f,
    //     pow(q.shape[-1], (-0.5)),
    //     false, -1, -1, false
    // ).front().view({B, N, num_heads * head_dim});

    return out_proj.forward(attn_output);
}
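mha_varlen_fwd consumes FlashAttention-style cumulative sequence-length arrays, which is why both cu_seqlens tensors must have batch_size + 1 entries. A host-side sketch of how such arrays are built for variable-length text conditions (the helper name and the example lengths are mine):

#include <vector>

// cu_seqlens[b] is the running token offset of batch element b.
std::vector<int> build_cu_seqlens(const std::vector<int> &lens) {
    std::vector<int> cu(lens.size() + 1, 0);
    for (size_t i = 0; i < lens.size(); i++)
        cu[i + 1] = cu[i] + lens[i];
    return cu;
}

// Example: two images of 1024 tokens each and prompts of 42 and 77 tokens give
// cu_seqlens_img = {0, 1024, 2048} and cu_seqlens_txt = {0, 42, 119}.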
SanaGLUMBConv::SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device) :
    in_features(in_features), hidden_features(hidden_features),
    inverted_conv(in_features, hidden_features * 2, true, dtype, device),
    depth_conv(hidden_features * 2, true, dtype, device),
    point_conv(hidden_features, in_features, false, dtype, device)
{
    registerChildren
        (inverted_conv, "inverted_conv")
        (depth_conv, "depth_conv")
        (point_conv, "point_conv")
    ;
}

Tensor SanaGLUMBConv::forward(Tensor x, int H, int W) {
    if (H <= 0 || W <= 0) {
        H = W = sqrt(x.shape[1]);
    }

    x = inverted_conv.forward_silu(x);
    x = x.view({x.shape[0], H, W, x.shape[-1]});
    debug("inverted_conv_output", x);

    x = depth_conv.forward(x);
    debug("depth_conv_output", x);

    x = x.view({x.shape[0], H * W, x.shape[-1]});

    auto qact = point_conv.quantize(x, true);
    return point_conv.forward_quant(qact);
}
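A shape walk-through of SanaGLUMBConv::forward, matching the view() calls above. The numbers are illustrative, and the assumption that point_conv.quantize(x, true) halves the channels via a GLU-style gate (so that point_conv's input width equals hidden_features) is mine, inferred from the constructor arguments rather than stated in the diff.

#include <cstdio>

int main() {
    int B = 2, H = 32, W = 32, C = 1024, Ch = 2560;  // illustrative: in_features C, hidden_features Ch
    std::printf("input         : [%d, %d, %d]\n", B, H * W, C);
    std::printf("inverted_conv : [%d, %d, %d]  (1x1 conv as GEMM, SiLU fused)\n", B, H * W, 2 * Ch);
    std::printf("depth_conv    : [%d, %d, %d, %d]  (3x3 depthwise, NHWC)\n", B, H, W, 2 * Ch);
    std::printf("point_conv in : [%d, %d, %d]  (after assumed GLU gate)\n", B, H * W, Ch);
    std::printf("output        : [%d, %d, %d]\n", B, H * W, C);
    return 0;
}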
SanaLinearTransformerBlock::SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device) :
    hidden_size(hidden_size), num_cross_attention_heads(num_cross_attention_heads),
    attn(hidden_size, false, pag, dtype, device),
    cross_attn(num_cross_attention_heads, hidden_size / num_cross_attention_heads, dtype, device),
    ff(hidden_size, intermediate_size, dtype, device),
    norm1(hidden_size, 1e-6, false, dtype, device),
    norm2(hidden_size, 1e-6, false, dtype, device)
{
    this->scale_shift_table = Tensor::allocate({6, hidden_size}, dtype, device);

    registerChildren
        (attn, "attn")
        (cross_attn, "cross_attn")
        (ff, "ff")
    ;
    registerParams
        (this->scale_shift_table, "scale_shift_table")
    ;
}

Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) {
    nvtxRangePushA("SanaLinearTransformerBlock");

    nvtxRangePushA("chunk");

    // Tensor ones = Tensor::ones({hidden_size}, Tensor::FP16, x.device());

    const int batch_size = timestep.shape[0];

    timestep = timestep.copy(timestep.device());
    timestep = timestep.view({batch_size, 6, hidden_size});
    kernels::mul_add_batch(timestep, {}, false, 0, this->scale_shift_table, false);
    debug("shifted_timestep", timestep);

    std::array<Tensor, 6> chunked;
    for (int i = 0; i < 6; i++) {
        chunked[i] = timestep.slice(1, i, i + 1);
    }
    auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = chunked;
    // auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(timestep);

    nvtxRangePop();

    {
        nvtxRangePushA("LinearAttention");

        Tensor residual = hidden_states;
        Tensor norm_hidden_states = norm1.forward(hidden_states);
        kernels::mul_add_batch(norm_hidden_states, scale_msa, true, 1, shift_msa, true);
        debug("norm_hidden_states_la", norm_hidden_states);

        Tensor attn_output = pag ? attn.forward_pag(norm_hidden_states, cfg) : attn.forward(norm_hidden_states);
        debug("attn_output_la", attn_output);

        kernels::mul_add_batch(attn_output, gate_msa, true, 0, residual, true);
        hidden_states = attn_output;

        nvtxRangePop();
    }

    {
        nvtxRangePushA("CrossAttention");

        debug("norm_hidden_states_cross", hidden_states);
        Tensor attn_output = cross_attn.forward(hidden_states, encoder_hidden_states, cu_seqlens_img, cu_seqlens_txt);
        debug("attn_output_cross", attn_output);

        kernels::mul_add_batch(attn_output, {}, false, 0, hidden_states, true);
        hidden_states = attn_output;

        nvtxRangePop();
    }

    {
        nvtxRangePushA("Feed-forward");

        debug("hidden_states_ff", hidden_states);
        Tensor norm_hidden_states = norm2.forward(hidden_states);
        kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 1, shift_mlp, true);
        debug("norm_hidden_states_ff", norm_hidden_states);

        Tensor ff_output = ff.forward(norm_hidden_states, H, W);
        debug("ff_output", ff_output);

        kernels::mul_add_batch(ff_output, gate_mlp, true, 0, hidden_states, true);
        hidden_states = ff_output;

        nvtxRangePop();
    }

    nvtxRangePop();

    debug("hidden_states_out", hidden_states);

    return hidden_states;
}
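The mul_add_batch calls above implement the usual adaLN-style modulation: with scale_shift = 1 the tensor is scaled by (scale + 1) and shifted, with scale_shift = 0 the branch output is gated and added onto the residual, and with no scale it is a plain residual add. A scalar reference of the two patterns (function names are mine):

#include <vector>

// y = norm(x) * (scale + 1) + shift   -- mul_add_batch(..., scale, true, /*scale_shift=*/1, shift, true)
std::vector<float> modulate(const std::vector<float> &normed,
                            const std::vector<float> &scale,
                            const std::vector<float> &shift) {
    std::vector<float> y(normed.size());
    for (size_t i = 0; i < y.size(); i++)
        y[i] = normed[i] * (scale[i] + 1.0f) + shift[i];
    return y;
}

// out = branch * gate + residual      -- mul_add_batch(..., gate, true, /*scale_shift=*/0, residual, true)
std::vector<float> gate_residual(const std::vector<float> &branch,
                                 const std::vector<float> &gate,
                                 const std::vector<float> &residual) {
    std::vector<float> out(branch.size());
    for (size_t i = 0; i < out.size(); i++)
        out[i] = branch[i] * gate[i] + residual[i];
    return out;
}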
SanaModel::SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device) : config(config) {
    const int inner_dim = config.num_attention_heads * config.attention_head_dim;
    for (int i = 0; i < config.num_layers; i++) {
        transformer_blocks.push_back(std::make_unique<SanaLinearTransformerBlock>(
            inner_dim,
            ceilDiv(int(round(config.expand_ratio * inner_dim)), 64) * 64,
            config.num_cross_attention_heads,
            std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
            dtype, device));
        registerChildren(*transformer_blocks.back(), format("transformer_blocks.{}", i));
    }
}

Tensor SanaModel::forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg) {
    for (int i = 0; i < config.num_layers; i++) {
        auto &&block = transformer_blocks[i];
        hidden_states = block->forward(
            hidden_states, encoder_hidden_states, timestep, cu_seqlens_img, cu_seqlens_txt, H, W,
            pag && std::find(config.pag_layers.begin(), config.pag_layers.end(), i) != config.pag_layers.end(),
            cfg);
    }
    return hidden_states;
}
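A hedged construction sketch showing how the new SanaModel could be instantiated from a SanaConfig. The field values are placeholders, not taken from any released SANA checkpoint, and the sketch assumes this repository's headers and Device type.

#include <memory>
#include "SanaModel.h"

std::unique_ptr<SanaModel> build_sana(Device device) {
    SanaConfig config;
    config.num_layers                = 20;    // illustrative values only
    config.num_attention_heads       = 36;
    config.attention_head_dim        = 32;
    config.num_cross_attention_heads = 16;
    config.expand_ratio              = 2.5;
    config.pag_layers                = {8};   // blocks that get the perturbed-attention-guidance branch
    return std::make_unique<SanaModel>(config, Tensor::BF16, device);
}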
src/SanaModel.h (new file, 0 → 100644)

#pragma once

#include "common.h"
#include "Tensor.h"
#include "Linear.h"
#include "layernorm.h"

class SanaLinearAttention : public Module {
public:
    SanaLinearAttention(int dim, bool bias, bool pag, Tensor::ScalarType dtype, Device device);
    Tensor forward(Tensor x, Tensor out = {});
    Tensor forward_pag(Tensor x, bool cfg);

public:
    const int dim;
    const int dim_pad;

private:
    GEMM_W4A4 qkv_proj;
    GEMM_W4A4 out_proj;
    std::optional<GEMM_W4A4> pag_to_v;
};

class MultiHeadCrossAttention : public Module {
public:
    MultiHeadCrossAttention(int num_heads, int head_dim, Tensor::ScalarType dtype, Device device);
    Tensor forward(Tensor x, Tensor cond, Tensor cu_seqlens_img, Tensor cu_seqlens_txt);

public:
    const int num_heads;
    const int head_dim;

private:
    GEMM_W4A4 q_linear;
    GEMM_F16 kv_linear;
    GEMM_W4A4 out_proj;
};

class SanaGLUMBConv : public Module {
public:
    SanaGLUMBConv(int in_features, int hidden_features, Tensor::ScalarType dtype, Device device);
    Tensor forward(Tensor x, int H, int W);

public:
    const int in_features;
    const int hidden_features;

private:
    GEMM_W4A4 inverted_conv;
    DWCONV depth_conv;
    GEMM_W4A4 point_conv;
};

class SanaLinearTransformerBlock : public Module {
public:
    SanaLinearTransformerBlock(int hidden_size, int intermediate_size, int num_cross_attention_heads, bool pag, Tensor::ScalarType dtype, Device device);
    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);

public:
    const int hidden_size;
    const int num_cross_attention_heads;

private:
    Tensor scale_shift_table;
    // Tensor ones;

    SanaLinearAttention attn;
    MultiHeadCrossAttention cross_attn;
    SanaGLUMBConv ff;
    LayerNorm norm1, norm2;
};

struct SanaConfig {
    int num_layers;
    int num_attention_heads;
    int attention_head_dim;
    int num_cross_attention_heads;
    double expand_ratio;
    std::vector<int> pag_layers;
};

class SanaModel : public Module {
public:
    SanaModel(SanaConfig config, Tensor::ScalarType dtype, Device device);
    Tensor forward(Tensor hidden_states, Tensor encoder_hidden_states, Tensor timestep, Tensor cu_seqlens_img, Tensor cu_seqlens_txt, int H, int W, bool pag, bool cfg);

public:
    const SanaConfig config;

public:
    std::vector<std::unique_ptr<SanaLinearTransformerBlock>> transformer_blocks;
};
\ No newline at end of file
src/Tensor.h

@@ -32,6 +32,10 @@ public:
     size_t getSize() { return size; }
     Device getDevice() { return device; }

+    virtual bool isAsyncBuffer() { return false; }
+
 protected:
     template<typename Derived>
     std::shared_ptr<Derived> shared_from_base() {

@@ -90,6 +94,9 @@ public:
         }
         checkCUDA(cudaFreeAsync(this->ptr, 0));
     }

+    virtual bool isAsyncBuffer() override { return true; }
 };

 class BufferCUDASync : public Buffer {

@@ -499,16 +506,16 @@ private:
         return cudaMemcpyDefault;
     }

-    static bool isAsyncBuffer(Buffer *buffer) {
-        return dynamic_cast<BufferCUDA *>(buffer);
-    }
+    // static bool isAsyncBuffer(Buffer *buffer) {
+    //     return dynamic_cast<BufferCUDA *>(buffer);
+    // }

     static inline std::map<cudaStream_t, std::set<std::shared_ptr<Buffer>>> lockedBuffers;

 public:
     // before launching an async operation, make sure to lock the buffer in case the buffer is freed before GPU completes
     static void lockBuffer(std::shared_ptr<Buffer> buffer, cudaStream_t stream) {
-        if (!isAsyncBuffer(buffer.get())) {
+        if (!buffer->isAsyncBuffer()) {
             lockedBuffers[stream].insert(buffer);
         }
     }
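The change replaces the dynamic_cast-to-BufferCUDA check with a virtual isAsyncBuffer() query, so any buffer type (including the torch-backed one below) can declare that it does not need stream-scoped locking. A hedged usage sketch of the locking contract described by the comment above; the kernel-launch and eventual unlock steps are outside this hunk, so they appear only as comments.

// Sketch only: keep a synchronously-freed buffer alive until the stream's work that reads it completes.
void launch_with_lock(std::shared_ptr<Buffer> buf, cudaStream_t stream) {
    Tensor::lockBuffer(buf, stream);   // no-op for buffers whose isAsyncBuffer() returns true
    // ... enqueue kernels on `stream` that read the buffer's memory ...
    // (the matching release on stream completion is handled elsewhere in Tensor.h)
}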
src/interop/torch.cpp

@@ -33,7 +33,7 @@ Tensor from_torch(at::Tensor input) {
     result.scalarType = mapType.at(input.scalar_type());
     result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
-    Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
+    // Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
     return result;
 }
src/interop/torch.h

@@ -13,6 +13,10 @@ public:
         this->device.type = this->tensor.is_cuda() ? Device::CUDA : Device::CPU;
         this->device.idx = this->tensor.get_device();
     }

+    virtual bool isAsyncBuffer() override {
+        // TODO: figure out how torch manages memory
+        return true;
+    }

 private:
     at::Tensor tensor;
 };
src/kernels/dispatch_cutlass.h (new file, 0 → 100644)

#pragma once

#include "common.h"
#include "Tensor.h"

#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/bfloat16.h>

template<typename F>
inline void dispatchF16(Tensor::ScalarType type, F &&func) {
    if (type == Tensor::FP16) {
        func.template operator()<cutlass::half_t>();
    } else if (type == Tensor::BF16) {
        func.template operator()<cutlass::bfloat16_t>();
    } else {
        assert(false);
    }
}
\ No newline at end of file
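dispatchF16 forwards to a templated lambda specialized on the CUTLASS half type, which is the pattern dwconv.cu and gemm_f16.cu use below. A minimal usage sketch (the element-size query is just an example of my own):

#include "dispatch_cutlass.h"

size_t element_size_of(Tensor::ScalarType type) {
    size_t size = 0;
    dispatchF16(type, [&]<typename half_t>() {
        size = sizeof(half_t);   // cutlass::half_t or cutlass::bfloat16_t, both 2 bytes
    });
    return size;
}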
src/kernels/dwconv.cu

 #include "common.h"
 #include "Tensor.h"
+#include "dispatch_cutlass.h"

 #include <cuda_runtime.h>
 #include "cutlass/cutlass.h"

@@ -10,6 +12,7 @@
 // depthwise_Conv2d operation cutlass_sm80_tensorop_f16_s16x8x16fprop_analytic_f16_256x128_64x3_nhwc_align8

 #if 0
 using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, 64>;
 using FilterShape = cutlass::MatrixShape<3, 3>;

@@ -194,3 +197,140 @@ Tensor depthwise_conv2d_kernel(Tensor A, Tensor B) {
     return D;
 }
 #endif

Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias) {
    assert(input.ndims() == 4);
    const int N  = input.size(0);
    const int H  = input.size(1);
    const int W  = input.size(2);
    const int C_ = input.size(3);

    assert(weight.ndims() == 4);
    const int K   = weight.size(0);
    const int R   = weight.size(1);
    const int S   = weight.size(2);
    const int C__ = weight.size(3);

    // weight = weight.copy(weight.device());

    dispatchF16(weight.dtype(), [&]<typename half_t>() {
        using ElementOutput = half_t;
        using ElementAccumulator = half_t;
        using ElementComputeEpilogue = half_t;
        using ElementInputA = half_t;
        using ElementInputB = half_t;

        using LayoutInputA = cutlass::layout::TensorNHWC;
        using LayoutInputB = cutlass::layout::TensorNHWC;
        using LayoutOutput = cutlass::layout::TensorNHWC;

        using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, 64>;
        using FilterShape = cutlass::MatrixShape<3, 3>;
        using ThreadblockShape = cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, 64, FilterShape::kCount>;
        using WarpShape = cutlass::gemm::GemmShape<16, 64, FilterShape::kCount>;
        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

        using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
            ElementInputA, LayoutInputA,
            ElementInputB, LayoutInputB,
            ElementOutput, LayoutOutput,
            ElementAccumulator,
            cutlass::arch::OpClassSimt,
            cutlass::arch::Sm80,
            ThreadblockShape,
            ThreadBlockOutputShape,
            FilterShape,
            WarpShape,
            InstructionShape,
            cutlass::epilogue::thread::LinearCombination<
                ElementOutput,
                128 / cutlass::sizeof_bits<ElementOutput>::value,
                ElementOutput,
                ElementComputeEpilogue>,
            cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
                1,
                ThreadBlockOutputShape::kN,
                ThreadBlockOutputShape::kH,
                ThreadBlockOutputShape::kW>,
            4,
            cutlass::arch::OpMultiplyAdd,
            cutlass::conv::IteratorAlgorithm::kFixedStrideDilation,
            cutlass::conv::StrideSupport::kFixed,
            cutlass::MatrixShape<1, 1>,
            cutlass::MatrixShape<1, 1>>::Kernel;

        using DeviceKernel = typename cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;

        cutlass::conv::Conv2dProblemSize problem_size(
            cutlass::Tensor4DCoord(N, H, W, C_),
            cutlass::Tensor4DCoord(K, R, S, C__),
            cutlass::Tensor4DCoord(1, 1, 1, 1),
            cutlass::MatrixCoord(1, 1),
            cutlass::MatrixCoord(1, 1),
            cutlass::conv::Mode::kCrossCorrelation,
            1,
            C_   // groups
        );

        const int P = problem_size.P;
        const int Q = problem_size.Q;

        if (!out.valid()) {
            out = Tensor::allocate({N, P, Q, K}, input.dtype(), input.device());
        }
        assert(out.ndims() == 4);
        assert(out.size(0) == N);
        assert(out.size(1) == P);
        assert(out.size(2) == Q);
        assert(out.size(3) == K);

        Tensor tmp_weight = Tensor::empty_like(weight);

        cutlass::TensorRef<ElementInputA, LayoutInputA> a_ref(
            input.data_ptr<ElementInputA>(),
            LayoutInputA(input.stride(2), input.stride(1), input.stride(0)));
        cutlass::TensorRef<ElementInputB, LayoutInputB> b_ref(
            weight.data_ptr<ElementInputB>(),
            LayoutInputB(weight.stride(2), weight.stride(1), weight.stride(0)));
        cutlass::TensorRef<ElementOutput, LayoutOutput> c_ref(
            bias.valid() ? bias.data_ptr<ElementOutput>() : out.data_ptr<ElementOutput>(),
            LayoutOutput(0, 0, 0));
        cutlass::TensorRef<ElementOutput, LayoutOutput> d_ref(
            out.data_ptr<ElementOutput>(),
            LayoutOutput(out.stride(2), out.stride(1), out.stride(0)));
        cutlass::TensorRef<ElementOutput, LayoutOutput> tmpw_ref(
            tmp_weight.data_ptr<ElementOutput>(),
            LayoutOutput(tmp_weight.stride(2), tmp_weight.stride(1), tmp_weight.stride(0)));

        typename DeviceKernel::Arguments arguments{
            problem_size,
            a_ref,
            b_ref,
            c_ref,
            d_ref,
            {ElementOutput(1.0f), ElementOutput(bias.valid() ? 1.0f : 0.0f)},
            tmpw_ref,
        };

        DeviceKernel implicit_gemm_op;

        size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
        BufferCUDA workspace(workspace_size);

        auto stream = getCurrentCUDAStream();

        cutlass::Status status = implicit_gemm_op.can_implement(arguments);
        if (status != cutlass::Status::kSuccess) {
            throw std::runtime_error("cutlass cannot implement");
        }

        status = implicit_gemm_op.initialize(arguments, workspace.getPtr(), stream);
        if (status != cutlass::Status::kSuccess) {
            throw std::runtime_error("cutlass cannot initialize");
        }

        status = implicit_gemm_op(stream);
        if (status != cutlass::Status::kSuccess) {
            throw std::runtime_error("cutlass cannot run");
        }
    });

    return out;
}
\ No newline at end of file
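For reference, a plain CPU version of the convolution dwconv_f16 configures above: NHWC layout, 3x3 filter, stride 1, padding 1, dilation 1, groups equal to the channel count (each channel filtered independently), cross-correlation (no filter flip), optional per-channel bias. This mirrors the Conv2dProblemSize arguments; it is a reference sketch with a simplified [C][3][3] weight layout, not the CUTLASS path.

#include <vector>

// input [N][H][W][C], weight [C][3][3], bias [C] (may be empty), output [N][H][W][C].
std::vector<float> dwconv3x3_ref(const std::vector<float> &input,
                                 const std::vector<float> &weight,
                                 const std::vector<float> &bias,
                                 int N, int H, int W, int C) {
    std::vector<float> out(input.size(), 0.0f);
    for (int n = 0; n < N; n++)
    for (int h = 0; h < H; h++)
    for (int w = 0; w < W; w++)
    for (int c = 0; c < C; c++) {
        float acc = bias.empty() ? 0.0f : bias[c];
        for (int r = 0; r < 3; r++)
        for (int s = 0; s < 3; s++) {
            int ih = h + r - 1, iw = w + s - 1;   // padding = 1, stride = 1
            if (ih < 0 || ih >= H || iw < 0 || iw >= W) continue;
            acc += input[((n * H + ih) * W + iw) * C + c] * weight[(c * 3 + r) * 3 + s];
        }
        out[((n * H + h) * W + w) * C + c] = acc;
    }
    return out;
}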
src/kernels/dwconv.h

@@ -3,4 +3,6 @@
 #include "common.h"
 #include "Tensor.h"

-Tensor depthwise_conv2d_kernel(Tensor A, Tensor B);
\ No newline at end of file
+// Tensor depthwise_conv2d_kernel(Tensor A, Tensor B);
+
+Tensor dwconv_f16(Tensor input, Tensor weight, Tensor out, Tensor bias);
\ No newline at end of file
src/kernels/gemm_batched.cu

@@ -47,7 +47,7 @@ Tensor gemm_batched_fp16(
     auto sizeO = cutlass::MatrixCoord(M, N);

     if (!out.valid()) {
-        auto outShape = a.shape;
+        auto outShape = TensorShape(a.shape.dataExtent);
         outShape[-1] = N;
         out = Tensor::empty(outShape, Tensor::FP32, a.device());
     }
src/kernels/gemm_f16.cu

 #include "gemm_f16.h"
+#include "dispatch_cutlass.h"

 #include <cutlass/core_io.h>
 #include <cutlass/cutlass.h>
 #include <cutlass/half.h>

@@ -13,8 +15,8 @@ using spdlog::fmt_lib::format;
 Tensor gemm_f16(Tensor input,   // FP16
                 Tensor weight,  // FP16
                 Tensor out,     // FP16
-                float alpha,
-                float beta
+                Tensor bias,
+                float alpha
 ) {
     auto N = weight.size(0);
     auto K = input.size(-1);

@@ -23,102 +25,111 @@ Tensor gemm_f16(Tensor input, // FP16
     spdlog::debug("gemm_f16: M={} K={} N={}", M, K, N);

-    using ElementOutput = cutlass::bfloat16_t;
-    using ElementAccumulator = float;
-    using ElementComputeEpilogue = cutlass::bfloat16_t;
-    using ElementInputA = cutlass::bfloat16_t;  // <- data type of elements in input matrix A
-    using ElementInputB = cutlass::bfloat16_t;  // <- data type of elements in input matrix B
-
-    using LayoutInputA = cutlass::layout::RowMajor;
-    using LayoutInputB = cutlass::layout::ColumnMajor;
-    using LayoutOutput = cutlass::layout::RowMajor;
-
-    // #if CUDA_ARCH >= 800
-    using Gemm = cutlass::gemm::device::Gemm<
-        ElementInputA, cutlass::layout::RowMajor,
-        ElementInputB, cutlass::layout::ColumnMajor,
-        ElementOutput, cutlass::layout::RowMajor,
-        ElementAccumulator,
-        cutlass::arch::OpClassTensorOp,
-        cutlass::arch::Sm80,
-        cutlass::gemm::GemmShape<128, 128, 64>,
-        cutlass::gemm::GemmShape<32, 64, 64>,
-        cutlass::gemm::GemmShape<16, 8, 16>,
-        cutlass::epilogue::thread::LinearCombination<
-            ElementOutput,
-            128 / cutlass::sizeof_bits<ElementOutput>::value,
-            ElementAccumulator,
-            ElementComputeEpilogue>,
-        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
-        3>;
-
-    auto input_size = cutlass::MatrixCoord(M, K);
-    auto weight_size = cutlass::MatrixCoord(K, N);
-    auto output_size = cutlass::MatrixCoord(M, N);
-
-    auto device = input.device();
-    // use the broadcasted bias as the output
-    // auto out = bias.to(device).view({1, -1}).repeat({M, 1});
-    if (!out.valid()) {
-        auto out_shape = input.shape;
-        out_shape[-1] = N;
-        out = Tensor::empty(out_shape, input.scalar_type(), input.device());
-    }
-
-    // FIXME: check contiguous of input if dims >= 3
-    assert(input.stride(-1) == 1);
-    // assert(input.is_contiguous());
-    assert(weight.is_contiguous());
-
-    assert(out.dtype() == input.scalar_type());
-    assert(out.shape[-1] == N);
-    assert(out.numel() / out.shape[-1] == M);
-    assert(out.stride(-1) == 1);
-    // FIXME: check contiguous of output if dims >= 3
-
-    // constexpr int kSparse = Gemm::kSparse;
-    // How many elements of A are covered per ElementE
-    // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
-    // The size of individual meta data
-    // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
-
-    cutlass::gemm::GemmCoord problem_size(M, N, K);
-
-    cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(input.data_ptr<ElementInputA>(), LayoutInputA(input.stride(-2)));
-    cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
-    cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(out.data_ptr<ElementOutput>(), LayoutOutput(out.stride(-2)));
-
-    typename Gemm::Arguments arguments{
-        problem_size,  // <- problem size of matrix multiplication
-        input_ref,     // <- reference to matrix A on device
-        weight_ref,    // <- reference to matrix B on device
-        out_ref,       // <- reference to matrix C on device
-        out_ref,       // <- reference to matrix D on device
-        {ElementOutput(alpha), ElementOutput(beta)},
-        1};
-
-    Gemm gemm_op;
-
-    // Using the arguments, query for extra workspace required for matrix
-    // multiplication computation
-    size_t workspace_size = Gemm::get_workspace_size(arguments);
-
-    // Allocate workspace memory
-    // cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
-    BufferCUDA workspace(workspace_size);
-
-    // Check the problem size is supported or not
-    cutlass::Status status = gemm_op.can_implement(arguments);
-    if (status != cutlass::Status::kSuccess) {
-        throw std::runtime_error(format("cutlass cannot implement M={} N={} K={}", M, N, K));
-    }
-
-    // Initialize CUTLASS kernel with arguments and workspace pointer
-    status = gemm_op.initialize(arguments, workspace.getPtr());
-    if (status != cutlass::Status::kSuccess) {
-        throw std::runtime_error("cutlass cannot initialize");
-    }
-
-    status = gemm_op();
-    if (status != cutlass::Status::kSuccess) {
-        throw std::runtime_error("cutlass cannot run");
-    }
+    dispatchF16(weight.dtype(), [&]<typename half_t>() {
+        using ElementOutput = half_t;
+        using ElementAccumulator = float;
+        using ElementComputeEpilogue = half_t;
+        using ElementInputA = half_t;  // <- data type of elements in input matrix A
+        using ElementInputB = half_t;  // <- data type of elements in input matrix B
+
+        using LayoutInputA = cutlass::layout::RowMajor;
+        using LayoutInputB = cutlass::layout::ColumnMajor;
+        using LayoutOutput = cutlass::layout::RowMajor;
+
+        // #if CUDA_ARCH >= 800
+        using Gemm = cutlass::gemm::device::Gemm<
+            ElementInputA, cutlass::layout::RowMajor,
+            ElementInputB, cutlass::layout::ColumnMajor,
+            ElementOutput, cutlass::layout::RowMajor,
+            ElementAccumulator,
+            cutlass::arch::OpClassTensorOp,
+            cutlass::arch::Sm80,
+            cutlass::gemm::GemmShape<128, 128, 64>,
+            cutlass::gemm::GemmShape<32, 64, 64>,
+            cutlass::gemm::GemmShape<16, 8, 16>,
+            cutlass::epilogue::thread::LinearCombination<
+                ElementOutput,
+                128 / cutlass::sizeof_bits<ElementOutput>::value,
+                ElementAccumulator,
+                ElementComputeEpilogue>,
+            cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+            3>;
+
+        auto input_size = cutlass::MatrixCoord(M, K);
+        auto weight_size = cutlass::MatrixCoord(K, N);
+        auto output_size = cutlass::MatrixCoord(M, N);
+
+        auto device = input.device();
+        // use the broadcasted bias as the output
+        // auto out = bias.to(device).view({1, -1}).repeat({M, 1});
+        if (!out.valid()) {
+            auto out_shape = TensorShape(input.shape.dataExtent);
+            out_shape[-1] = N;
+            out = Tensor::empty(out_shape, input.scalar_type(), input.device());
+        }
+
+        // FIXME: check contiguous of input if dims >= 3
+        assert(input.stride(-1) == 1);
+        // assert(input.is_contiguous());
+        assert(weight.is_contiguous());
+
+        assert(out.dtype() == input.scalar_type());
+        assert(out.shape[-1] == N);
+        assert(out.numel() / out.shape[-1] == M);
+        assert(out.stride(-1) == 1);
+        // FIXME: check contiguous of output if dims >= 3
+
+        assert(!bias.valid() || (bias.ndims() == 1 && bias.shape[0] == N));
+
+        // constexpr int kSparse = Gemm::kSparse;
+        // How many elements of A are covered per ElementE
+        // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
+        // The size of individual meta data
+        // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
+
+        cutlass::gemm::GemmCoord problem_size(M, N, K);
+
+        cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(input.data_ptr<ElementInputA>(), LayoutInputA(input.stride(-2)));
+        cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
+        cutlass::TensorRef<ElementOutput, LayoutOutput> bias_ref(bias.valid() ? bias.data_ptr<ElementOutput>() : out.data_ptr<ElementOutput>(), LayoutOutput(0));
+        cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(out.data_ptr<ElementOutput>(), LayoutOutput(out.stride(-2)));
+
+        typename Gemm::Arguments arguments{
+            problem_size,  // <- problem size of matrix multiplication
+            input_ref,     // <- reference to matrix A on device
+            weight_ref,    // <- reference to matrix B on device
+            bias_ref,      // <- reference to matrix C on device
+            out_ref,       // <- reference to matrix D on device
+            {ElementOutput(alpha), ElementOutput(bias.valid() ? 1.0f : 0.0f)},
+            1};
+
+        Gemm gemm_op;
+
+        // Using the arguments, query for extra workspace required for matrix
+        // multiplication computation
+        size_t workspace_size = Gemm::get_workspace_size(arguments);
+
+        // Allocate workspace memory
+        // cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+        BufferCUDA workspace(workspace_size);
+
+        // Check the problem size is supported or not
+        cutlass::Status status = gemm_op.can_implement(arguments);
+        if (status != cutlass::Status::kSuccess) {
+            throw std::runtime_error(format("cutlass cannot implement M={} N={} K={}", M, N, K));
+        }
+
+        // Initialize CUTLASS kernel with arguments and workspace pointer
+        status = gemm_op.initialize(arguments, workspace.getPtr());
+        if (status != cutlass::Status::kSuccess) {
+            throw std::runtime_error("cutlass cannot initialize");
+        }
+
+        status = gemm_op();
+        if (status != cutlass::Status::kSuccess) {
+            throw std::runtime_error("cutlass cannot run");
+        }
+    });

     return out;
 }
src/kernels/gemm_f16.h

@@ -7,6 +7,6 @@ Tensor gemm_f16(
     Tensor input,   // FP16
     Tensor weight,  // FP16
     Tensor out,     // FP16
-    float alpha,
-    float beta
+    Tensor bias,
+    float alpha
 );
\ No newline at end of file
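With this change the epilogue computes D = alpha * (A x B^T) + bias, where the row-broadcast bias term is applied only when a bias tensor is passed, replacing the old alpha/beta pair. A hedged call sketch (shapes are illustrative; the wrapper name is mine):

#include "gemm_f16.h"

// out[M, N] = input[M, K] * weight[N, K]^T + bias[N]
Tensor linear_f16(Tensor input, Tensor weight, Tensor bias) {
    Tensor out;  // left empty so gemm_f16 allocates the output itself
    return gemm_f16(input, weight, out, bias, /*alpha=*/1.0f);
}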
src/kernels/gemm_w8a8.cu

@@ -82,7 +82,7 @@ Tensor gemm_w8a8_fp16(Tensor input, // INT8
     // auto out = bias.to(device).view({1, -1}).repeat({M, 1});
     if (!out.valid()) {
-        auto out_shape = input.shape;
+        auto out_shape = TensorShape(input.shape.dataExtent);
         out_shape[-1] = N;
         out = Tensor::empty(out_shape, Tensor::FP16, input.device());
     }
src/kernels/misc_kernels.cu

@@ -2,6 +2,8 @@
 #include "misc_kernels.h"
 #include "dispatch_utils.h"

+namespace nunchaku::kernels {
+
 Tensor add(Tensor a, Tensor b) {
     assert(a.shape.dataExtent == b.shape.dataExtent);
     assert(a.dtype() == b.dtype());

@@ -34,11 +36,11 @@ void mul_add(Tensor x, Tensor scale, Tensor bias) {
     constexpr int unroll = 8;
     assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);
-    assert((uintptr_t)scale.data_ptr() % (x.scalar_size() * unroll) == 0);
+    assert(!scale.valid() || (uintptr_t)scale.data_ptr() % (x.scalar_size() * unroll) == 0);
     assert((uintptr_t)bias.data_ptr() % (x.scalar_size() * unroll) == 0);

     assert(x.numel() % unroll == 0);
-    assert(scale.numel() % unroll == 0);
+    assert(!scale.valid() || scale.numel() % unroll == 0);
     assert(bias.numel() % unroll == 0);

     int threadsPerBlock = 1024;

@@ -47,8 +49,60 @@ void mul_add(Tensor x, Tensor scale, Tensor bias) {
     auto stream = getCurrentCUDAStream();

     dispatch(x.scalar_type(), [&]<typename scalar_t>() {
-        mul_add_kernel<scalar_t, unroll><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
-            x.data_ptr<scalar_t>(), scale.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
-            x.numel(), scale.numel(), bias.numel());
+        if (scale.valid()) {
+            mul_add_kernel<scalar_t, unroll, false><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
+                x.data_ptr<scalar_t>(), scale.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
+                0, x.numel(), scale.numel(), bias.numel(), 0, 0, 0);
+        } else {
+            mul_add_kernel<scalar_t, unroll, true><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
+                x.data_ptr<scalar_t>(), nullptr, bias.data_ptr<scalar_t>(),
+                0, x.numel(), 1, bias.numel(), 0, 0, 0);
+        }
     });
 }

+void mul_add_batch(Tensor x, Tensor scale, bool batch_scale, double scale_shift, Tensor bias, bool batch_bias) {
+    const int batch_size = x.shape[0];
+    assert(!batch_scale || scale.shape[0] == batch_size);
+    assert(!batch_bias || bias.shape[0] == batch_size);
+
+    const int numel = x.numel() / batch_size;
+    const int numel_scale = scale.valid() ? (scale.numel() / (batch_scale ? batch_size : 1)) : 1;
+    const int numel_bias = bias.numel() / (batch_bias ? batch_size : 1);
+
+    assert(numel % numel_scale == 0);
+    assert(numel % numel_bias == 0);
+
+    assert(!scale.valid() || x.dtype() == scale.dtype());
+    assert(x.dtype() == bias.dtype());
+
+    constexpr int unroll = 8;
+    assert((uintptr_t)x.data_ptr() % (x.scalar_size() * unroll) == 0);
+    assert(!scale.valid() || (uintptr_t)scale.data_ptr() % (x.scalar_size() * unroll) == 0);
+    assert((uintptr_t)bias.data_ptr() % (x.scalar_size() * unroll) == 0);
+
+    assert(numel % unroll == 0);
+    assert(!scale.valid() || numel_scale % unroll == 0);
+    assert(numel_bias % unroll == 0);
+
+    int threadsPerBlock = 1024;
+    dim3 grid(ceilDiv(numel, threadsPerBlock * unroll), batch_size);
+
+    auto stream = getCurrentCUDAStream();
+
+    dispatch(x.scalar_type(), [&]<typename scalar_t>() {
+        if (scale.valid()) {
+            mul_add_kernel<scalar_t, unroll, false><<<grid, threadsPerBlock, 0, stream>>>(
+                x.data_ptr<scalar_t>(), scale.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(),
+                (scalar_t)scale_shift, numel, numel_scale, numel_bias,
+                x.stride(0), batch_scale ? scale.stride(0) : 0, batch_bias ? bias.stride(0) : 0);
+        } else {
+            mul_add_kernel<scalar_t, unroll, true><<<grid, threadsPerBlock, 0, stream>>>(
+                x.data_ptr<scalar_t>(), nullptr, bias.data_ptr<scalar_t>(),
+                (scalar_t)scale_shift, numel, 1, numel_bias,
+                x.stride(0), 0, batch_bias ? bias.stride(0) : 0);
+        }
+    });
+}

@@ -219,7 +273,7 @@ Tensor topk(Tensor x, int k) {
     assert(k <= N);
     assert(k <= MAXK);

-    auto outShape = x.shape;
+    auto outShape = TensorShape(x.shape.dataExtent);
     outShape[-1] = k;
     outShape.dataStride.clear();

@@ -252,4 +306,6 @@ template std::array<Tensor, 2> split_mod<2>(Tensor input);
 template std::array<Tensor, 3> split_mod<3>(Tensor input);
 template std::array<Tensor, 4> split_mod<4>(Tensor input);
 template std::array<Tensor, 5> split_mod<5>(Tensor input);
-template std::array<Tensor, 6> split_mod<6>(Tensor input);
\ No newline at end of file
+template std::array<Tensor, 6> split_mod<6>(Tensor input);
+
+};  // namespace nunchaku::kernels
\ No newline at end of file
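The semantics of the new mul_add_batch, written as a scalar reference that mirrors the kernel's indexing: scale and bias repeat with period numel_scale / numel_bias inside each batch element, either operand may carry its own batch dimension, and when scale is absent the op degenerates to x += bias (scale_shift is then ignored). This is a reference sketch, not the CUDA path.

#include <vector>
#include <cstddef>

// x is [batch][numel]; scale/bias are [1][numel_*] (shared) or [batch][numel_*] (batched).
void mul_add_batch_ref(std::vector<std::vector<float>> &x,
                       const std::vector<std::vector<float>> &scale, bool batch_scale, float scale_shift,
                       const std::vector<std::vector<float>> &bias,  bool batch_bias) {
    for (size_t b = 0; b < x.size(); b++) {
        const std::vector<float> *s = scale.empty() ? nullptr : &scale[batch_scale ? b : 0];
        const std::vector<float> &v = bias[batch_bias ? b : 0];
        for (size_t i = 0; i < x[b].size(); i++) {
            float sc = s ? ((*s)[i % s->size()] + scale_shift) : 1.0f;
            x[b][i] = x[b][i] * sc + v[i % v.size()];
        }
    }
}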
src/kernels/misc_kernels.h

@@ -3,8 +3,12 @@
 #include "common.h"
 #include "Tensor.h"

+namespace nunchaku::kernels {
+
 Tensor add(Tensor a, Tensor b);

 void mul_add(Tensor x, Tensor scale, Tensor bias);
+void mul_add_batch(Tensor x, Tensor scale, bool batch_scale, double scale_shift, Tensor bias, bool batch_bias);

 Tensor embedding(Tensor input_id, Tensor lookup);
 Tensor argmax_sample(Tensor logits);
 void splitqkv(Tensor qkv, Tensor q, Tensor k, Tensor v);

@@ -16,4 +20,6 @@ void cast(Tensor input, Tensor output);
 Tensor topk(Tensor x, int k);

 template<size_t N>
-std::array<Tensor, N> split_mod(Tensor input);
\ No newline at end of file
+std::array<Tensor, N> split_mod(Tensor input);
+
+};  // namespace nunchaku::kernels
\ No newline at end of file
src/kernels/misc_kernels_impl.cuh

@@ -7,6 +7,8 @@
 #include "utils.cuh"
 #include "activation_kernels_impl.cuh"

+namespace nunchaku::kernels {
+
 template<typename T>
 __global__ void add_kernel(T *a, T *b, T *c, size_t length) {

@@ -21,9 +23,9 @@ struct alignas(sizeof(T) * unroll) Tvec {
     T data[unroll];
 };

-template<typename T, int unroll>
-__global__ void mul_add_kernel(T *x, T *scale, T *bias, size_t length, int mod_scale, int mod_bias) {
+template<typename T, int unroll, bool no_scale>
+__global__ void mul_add_kernel(T *x, T *scale, T *bias, T scale_shift, size_t length, int mod_scale, int mod_bias, int64_t batch_stride_x, int64_t batch_stride_scale, int64_t batch_stride_bias) {
+    const int batch_id = blockIdx.y;
     int thread = threadIdx.x + blockIdx.x * blockDim.x;
     int i = thread * unroll;
     int i_scale = i % mod_scale;

@@ -33,15 +35,20 @@ __global__ void mul_add_kernel(T *x, T *scale, T *bias, size_t length, int mod_s
         return;
     }

-    using Tvec = ::Tvec<T, unroll>;
+    using Tvec = nunchaku::kernels::Tvec<T, unroll>;

-    Tvec rx = *reinterpret_cast<Tvec *>(&x[i]);
-    Tvec rscale = *reinterpret_cast<Tvec *>(&scale[i_scale]);
-    Tvec rbias = *reinterpret_cast<Tvec *>(&bias[i_bias]);
+    Tvec rx = *reinterpret_cast<Tvec *>(&x[i + batch_stride_x * batch_id]);
+    Tvec rscale = *reinterpret_cast<Tvec *>(&scale[i_scale + batch_stride_scale * batch_id]);
+    Tvec rbias = *reinterpret_cast<Tvec *>(&bias[i_bias + batch_stride_bias * batch_id]);

 #pragma unroll
     for (int k = 0; k < unroll; k++) {
-        T tmp = rx.data[k] * rscale.data[k] + rbias.data[k];
+        T tmp;
+        if constexpr (no_scale) {
+            tmp = rx.data[k] + rbias.data[k];
+        } else {
+            tmp = rx.data[k] * (rscale.data[k] + scale_shift) + rbias.data[k];
+        }
         if constexpr (std::is_same_v<T, half>) {
             tmp = __hmin(tmp, (half)65504);
             tmp = __hmax(tmp, (half)-65504);

@@ -49,7 +56,7 @@ __global__ void mul_add_kernel(T *x, T *scale, T *bias, size_t length, int mod_s
         rx.data[k] = tmp;
     }

-    *reinterpret_cast<Tvec *>(&x[i]) = rx;
+    *reinterpret_cast<Tvec *>(&x[i + batch_stride_x * batch_id]) = rx;

     // #pragma unroll
     // for (int k = 0; k < unroll; k++) {

@@ -127,8 +134,8 @@ __global__ void quant_kernel_static(const T * input, int8_t * output, T scale, s
         return;
     }

-    using Tvec = ::Tvec<T, unroll>;
-    using I8vec = ::Tvec<int8_t, unroll>;
+    using Tvec = nunchaku::kernels::Tvec<T, unroll>;
+    using I8vec = nunchaku::kernels::Tvec<int8_t, unroll>;

     Tvec rinput = *reinterpret_cast<const Tvec *>(&input[i]);
     I8vec routput;

@@ -149,8 +156,8 @@ __global__ void quant_kernel_static_fuse_gelu(const T * input, int8_t * output,
         return;
     }

-    using Tvec = ::Tvec<T, unroll>;
-    using I8vec = ::Tvec<int8_t, unroll>;
+    using Tvec = nunchaku::kernels::Tvec<T, unroll>;
+    using I8vec = nunchaku::kernels::Tvec<int8_t, unroll>;

     Tvec rinput = *reinterpret_cast<const Tvec *>(&input[i]);
     I8vec routput;

@@ -168,8 +175,8 @@ template<typename Tin, typename Tout, int unroll>
 __global__ void cast_kernel(const Tin *input, Tout *output, size_t length) {
     const int i = (blockIdx.x * blockDim.x + threadIdx.x) * unroll;

-    using Tvec_in = ::Tvec<Tin, unroll>;
-    using Tvec_out = ::Tvec<Tout, unroll>;
+    using Tvec_in = nunchaku::kernels::Tvec<Tin, unroll>;
+    using Tvec_out = nunchaku::kernels::Tvec<Tout, unroll>;

     Tvec_in rinput = *reinterpret_cast<const Tvec_in *>(&input[i]);
     Tvec_out routput;

@@ -250,4 +257,6 @@ void topk_kernel(const T *input, int *output, int N, int strideInput, int numRow
     for (int i = 0; i < K; i++) {
         output[row * K + i] = idx[K - i - 1];
     }
-}
\ No newline at end of file
+}
+
+};  // namespace nunchaku::kernels
\ No newline at end of file
src/kernels/zgemm/gemm_base.cuh (new file, 0 → 100644) — this diff is collapsed.
src/kernels/gemm_utils.cuh → src/kernels/zgemm/gemm_utils.cuh

@@ -2,7 +2,9 @@
 #include <cstdint>

 #include "common.h"
-#include "utils.cuh"
+#include "../utils.cuh"
+
+namespace nunchaku::kernels {

 static constexpr int clamp(int val, int min, int max) {
     if (val < min)

@@ -74,25 +76,19 @@ static void store(T *addr, T val) {
     *addr = val;
 }

-template<typename T>
-__device__ __forceinline__
-float2 half22float2(T val);
-
-template<>
 __device__ __forceinline__
-float2 half22float2<half2>(half2 val) {
+static float2 half22float2(half2 val) {
     return __half22float2(val);
 }

-template<>
 __device__ __forceinline__
-float2 half22float2<__nv_bfloat162>(__nv_bfloat162 val) {
+static float2 half22float2(__nv_bfloat162 val) {
     return __bfloat1622float2(val);
 }

 template<typename T>
 __device__ __forceinline__
-T float22half2(float2 val);
+static T float22half2(float2 val) = delete;

 template<>
 __device__ __forceinline__

@@ -108,7 +104,7 @@ __nv_bfloat162 float22half2<__nv_bfloat162>(float2 val) {
 template<typename T>
 __device__ __forceinline__
-void unused_var(T &val, bool alwaysfalse) {
+static void unused_var(T &val, bool alwaysfalse) {
     volatile T *ptr = nullptr;
     if (alwaysfalse) {
         *ptr = val;

@@ -218,7 +214,7 @@ static float cuda_sigmoidf (float a)
 template<typename T>
 __device__ __forceinline__
 static T gelu_half2(T x) {
-    float2 xf = half22float2<T>(x);
+    float2 xf = half22float2(x);
     float2 x3f = xf * xf * xf;
     float t1 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.x + (0.044715f * x3f.x)));
     float t2 = 0.5f + 0.5f * cuda_tanhf(0.79788456f * (xf.y + (0.044715f * x3f.y)));

@@ -242,6 +238,25 @@ static T silu(const T &x) {
     // return (T)__fdividef((float)x, 1.0f + __expf((float)-x));
 }

+__device__ __forceinline__
+static half2 h2div(half2 a, half2 b) {
+    float2 af = half22float2(a);
+    float2 bf = half22float2(b);
+    float2 of;
+    of.x = __fdividef(af.x, bf.x);
+    of.y = __fdividef(af.y, bf.y);
+    return float22half2<half2>(of);
+};
+
+__device__ __forceinline__
+static __nv_bfloat162 h2div(__nv_bfloat162 a, __nv_bfloat162 b) {
+    float2 af = half22float2(a);
+    float2 bf = half22float2(b);
+    float2 of;
+    of.x = __fdividef(af.x, bf.x);
+    of.y = __fdividef(af.y, bf.y);
+    return float22half2<__nv_bfloat162>(of);
+};
+
 __device__ __forceinline__
 static void reduce_add(float *addr, float val) {
     asm volatile ("red.relaxed.gpu.global.add.f32 [%0], %1;" :: "l"(addr), "f"(val));

@@ -254,4 +269,6 @@ static void unrolled_loop(F &&lambda) {
         (lambda.template operator()<Is>(), ...);
     };
     call(std::make_integer_sequence<int, cnt>());
-}
\ No newline at end of file
+}
+
+};  // namespace nunchaku::kernels
\ No newline at end of file
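The new h2div overloads divide two packed half/bfloat16 lanes by converting through float2, avoiding half-precision division. A minimal device-side usage sketch; it assumes this header is on the include path and the kernel itself is illustrative, not part of the commit.

#include <cuda_fp16.h>
#include "gemm_utils.cuh"   // nunchaku::kernels::h2div after this commit

using namespace nunchaku::kernels;

// Elementwise a / b over packed half2 lanes.
__global__ void div_half2_kernel(const half2 *a, const half2 *b, half2 *out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = h2div(a[i], b[i]);
}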
src/kernels/zgemm/gemm_w4a4.cu (new file, 0 → 100644)

#include "zgemm.h"
#include "gemm_w4a4_launch.cuh"

namespace nunchaku::kernels {

template<typename F>
static void invoke_launch(Tensor::ScalarType dtype, F &&launch) {
    if (dtype == Tensor::FP16) {
        launch.template operator()<GEMMConfig_W4A4_FP16>();
    } else if (dtype == Tensor::BF16) {
        launch.template operator()<GEMMConfig_W4A4_BF16>();
    } else {
        assert(false);
    }
}

void gemm_w4a4(
    Tensor act,            // packed act [M, K / 2]
    Tensor wgt,            // packed act [N, K / 2]
    Tensor out,            // linear [M, N]
    Tensor qout,           // packed act [M, N / 2]
    Tensor ascales,        // packed as [K / 64, M]
    Tensor wscales,        // packed ws [K / 64, N]
    Tensor oscales,        // packed as [N / 64, M]
    Tensor poolout,        // linear [M / PoolSize, N]
    Tensor lora_act_in,    // packed lora_act [M, R]
    Tensor lora_up,        // packed lora_wgt [N, R]
    Tensor lora_down,      // packed lora_wgt [N, R]
    Tensor lora_act_out,   // packed lora_act [M, R]
    Tensor norm_q,         // linear [HEAD_DIM]
    Tensor norm_k,         // linear [HEAD_DIM]
    Tensor rotary_emb,     // linear [M, HEAD_DIM / 2, 2, 2]
    Tensor bias,           // packed ws [N]
    Tensor smooth_factor,  // packed ws [N], for quantization of the next layer
    Tensor out_vk,         // linear [B, num_heads, head_dim + 1, head_dim]
    Tensor out_linearattn, // linear [B, (M), N / 3]
    bool act_unsigned,
    std::vector<float> lora_scales, // [R / 16]
    bool fuse_silu)
{
    invoke_launch(ascales.dtype(), [&]<typename Config>() {
        GEMM_W4A4_Launch<Config>::gemm_w4a4(
            act, wgt, out, qout, ascales, wscales, oscales, poolout,
            lora_act_in, lora_up, lora_down, lora_act_out,
            norm_q, norm_k, rotary_emb, bias, smooth_factor,
            out_vk, out_linearattn, act_unsigned, lora_scales, fuse_silu);
    });
}

void linearattn_vk_mul_q(Tensor q, Tensor vk) {
    invoke_launch(q.dtype(), [&]<typename Config>() {
        GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(q, vk);
    });
}

void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu) {
    invoke_launch(input.dtype(), [&]<typename Config>() {
        GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu);
    });
}

void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
    invoke_launch(input.dtype(), [&]<typename Config>() {
        GEMM_W4A4_Launch<Config>::quantize_w4a4_act(input, output, oscales);
    });
}

void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
    invoke_launch(input.dtype(), [&]<typename Config>() {
        GEMM_W4A4_Launch<Config>::quantize_w4a4_wgt(input, output, oscales);
    });
}

};
\ No newline at end of file
src/kernels/gemm_w4a4.cu → src/kernels/zgemm/gemm_w4a4.cuh (renamed) — this diff is collapsed.
src/kernels/zgemm/gemm_w4a4_launch.cuh (new file, 0 → 100644)

#include "gemm_w4a4.cuh"

namespace nunchaku::kernels {

template<typename Config>
class GEMM_W4A4_Launch {
    using GEMM = GEMM_W4A4<Config>;

    using LoraRanks = std::integer_sequence<int, 0, 32, 48, 64, 80, 96>;
    // using LoraRanks = std::integer_sequence<int, 32>;

    using packed_act_t = typename GEMM::packed_act_t;
    using packed_wgt_t = typename GEMM::packed_wgt_t;
    using packed_ascale_t = typename GEMM::packed_ascale_t;
    using packed_wscale_t = typename GEMM::packed_wscale_t;
    using packed_fpsum_t = typename GEMM::packed_fpsum_t;
    using half_t = typename GEMM::half_t;

public:
    static void gemm_w4a4(
        Tensor act,            // packed act [M, K / 2]
        Tensor wgt,            // packed act [N, K / 2]
        Tensor out,            // linear [M, N]
        Tensor qout,           // packed act [M, N / 2]
        Tensor ascales,        // packed as [K / 64, M]
        Tensor wscales,        // packed ws [K / 64, N]
        Tensor oscales,        // packed as [N / 64, M]
        Tensor poolout,        // linear [M / PoolSize, N]
        Tensor lora_act_in,    // packed lora_act [M, R]
        Tensor lora_up,        // packed lora_wgt [N, R]
        Tensor lora_down,      // packed lora_wgt [N, R]
        Tensor lora_act_out,   // packed lora_act [M, R]
        Tensor norm_q,         // linear [HEAD_DIM]
        Tensor norm_k,         // linear [HEAD_DIM]
        Tensor rotary_emb,     // linear [M, HEAD_DIM / 2, 2, 2]
        Tensor bias,           // packed ws [N]
        Tensor smooth_factor,  // packed ws [N], for quantization of the next layer
        Tensor out_vk,         // linear [B, num_heads, head_dim + 1, head_dim]
        Tensor out_linearattn, // linear [B, (M), N / 3]
        bool act_unsigned,
        std::vector<float> lora_scales, // [R / 16]
        bool fuse_silu);
    static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu);
    static void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
    static void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
    static void linearattn_vk_mul_q(Tensor q, Tensor vk);
};

};  // namespace nunchaku::kernels
\ No newline at end of file