fengzch-das / nunchaku · Commits

Commit 004e4e31
authored Feb 26, 2025 by Samuel Tesfai
Adding INT16 to ScalarType Tensor
Finalizing deepcompressor migration
parent 218d333e

Showing 4 changed files with 10 additions and 8 deletions:

nunchaku/csrc/pybind.cpp   +1 -4
nunchaku/models/linear.py  +3 -3
src/Tensor.h               +2 -1
src/interop/torch.cpp      +4 -0
nunchaku/csrc/pybind.cpp

@@ -11,9 +11,6 @@
 #include <pybind11/pybind11.h>
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("awq_gemm_forward_cuda", &awq_gemm_forward_cuda, "AWQ quantized GEMM kernel.");
-    m.def("gemv_awq", &gemv_awq, "AWQ quantized GEMV kernel.");
     py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
         .def(py::init<>())
         .def("init", &QuantizedFluxModel::init,

@@ -76,7 +73,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     ;
     m.def_submodule("ops")
-        .def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
+        .def("gemm_cuda", nunchaku::ops::gemm_cuda)
         .def("gemv_awq", nunchaku::ops::gemv_awq)
     ;
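For context, after this hunk the raw kernels are no longer bound at the top level of the extension and are reached only through its ops submodule. A minimal usage sketch, assuming the compiled extension is importable as nunchaku._C (consistent with the linear.py change below):

# Sketch only: assumes the extension is built and importable as nunchaku._C.
from nunchaku._C import ops

# The hunk above binds gemm_cuda and gemv_awq onto the "ops" submodule,
# so both should be reachable as attributes here.
print(callable(ops.gemm_cuda), callable(ops.gemv_awq))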
nunchaku/models/linear.py

@@ -6,7 +6,7 @@ import warnings
 import torch
 import torch.nn as nn
-from nunchaku.csrc.load import _C
+from nunchaku._C.ops import gemm_cuda, gemv_awq
 from .tinychat_utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight
 __all__ = ["W4Linear"]

@@ -78,7 +78,7 @@ class W4Linear(nn.Module):
     @torch.no_grad()
     def forward(self, x):
         if x.numel() / x.shape[-1] < 8:
-            out = _C.awq_gemv_forward_cuda(
+            out = gemv_awq(
                 x,
                 self.qweight,
                 self.scales,

@@ -89,7 +89,7 @@ class W4Linear(nn.Module):
                 self.group_size,
             )
         else:
-            out = _C.awq_gemm_forward_cuda(x, self.qweight, self.scales, self.scaled_zeros)
+            out = gemm_cuda(x, self.qweight, self.scales, self.scaled_zeros)
         out = out + self.bias if self.bias is not None else out
         return out
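The dispatch in W4Linear.forward is unchanged by this commit apart from the kernel names: when the flattened batch has fewer than 8 rows the GEMV kernel is used, otherwise the GEMM kernel. A small sketch of just that threshold logic, with the CUDA kernels themselves left out since only their call sites appear in this diff:

import torch

def pick_kernel(x: torch.Tensor) -> str:
    # x.numel() / x.shape[-1] is the number of rows after flattening all
    # leading dimensions; fewer than 8 rows favours the GEMV path.
    rows = x.numel() // x.shape[-1]
    return "gemv_awq" if rows < 8 else "gemm_cuda"

print(pick_kernel(torch.empty(1, 4096)))   # gemv_awq
print(pick_kernel(torch.empty(32, 4096)))  # gemm_cuda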
src/Tensor.h

@@ -217,7 +217,7 @@ class Tensor {
 public:
     enum ScalarType {
         INVALID_SCALAR_TYPE,
-        INT8, INT32, INT64,
+        INT8, INT16, INT32, INT64,
         FP16, FP32, BF16
     };

@@ -540,6 +540,7 @@ public:
 inline const std::map<Tensor::ScalarType, size_t> Tensor::scalarSize = {
     { INT8, 1 },
+    { INT16, 2 },
     { INT32, 4 },
     { INT64, 8 },
     { FP16, 2 },
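The new scalarSize entry records that INT16 occupies 2 bytes, which has to agree with the PyTorch dtype it is mapped to in src/interop/torch.cpp below. A quick sanity sketch of that invariant from the Python side (not part of the commit):

import torch

# Tensor::INT16 maps to at::ScalarType::Short (torch.int16); its element size
# must match the { INT16, 2 } entry added to Tensor::scalarSize.
assert torch.tensor(0, dtype=torch.int16).element_size() == 2
assert torch.tensor(0, dtype=torch.float16).element_size() == 2  # FP16 entry
assert torch.tensor(0, dtype=torch.int64).element_size() == 8    # INT64 entry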
src/interop/torch.cpp

@@ -28,6 +28,8 @@ Tensor from_torch(at::Tensor input) {
     { at::ScalarType::Float, Tensor::FP32 },
     { at::ScalarType::Half, Tensor::FP16 },
     { at::ScalarType::BFloat16, Tensor::BF16 },
+    { at::ScalarType::Short, Tensor::INT16 },
     };
     result.scalarType = mapType.at(input.scalar_type());

@@ -53,6 +55,8 @@ at::Tensor to_torch(Tensor input) {
     { Tensor::FP32, at::ScalarType::Float },
     { Tensor::FP16, at::ScalarType::Half },
     { Tensor::BF16, at::ScalarType::BFloat16 },
+    { Tensor::INT16, at::ScalarType::Short },
     };
     c10::TensorOptions opts(mapType.at(input.scalar_type()));
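Together these two mappings make int16 tensors round-trippable between torch and the internal Tensor type: at::ScalarType::Short is the C++ name for torch.int16 (alias torch.short). A small illustration of that dtype identity on the Python side, hedged since from_torch/to_torch themselves are not exposed to Python directly:

import torch

# at::ScalarType::Short corresponds to torch.int16 / torch.short in Python.
assert torch.short is torch.int16
x = torch.arange(4, dtype=torch.int16)
print(x.dtype)  # torch.int16 -> would map to Tensor::INT16 via from_torch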