Commit 004e4e31 authored by Samuel Tesfai

Adding INT16 to Tensor ScalarType

Finalizing deepcompressor migration
parent 218d333e
@@ -11,9 +11,6 @@
 #include <pybind11/pybind11.h>
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("awq_gemm_forward_cuda", &awq_gemm_forward_cuda, "AWQ quantized GEMM kernel.");
-    m.def("gemv_awq", &gemv_awq, "AWQ quantized GEMV kernel.");
     py::class_<QuantizedFluxModel>(m, "QuantizedFluxModel")
         .def(py::init<>())
         .def("init", &QuantizedFluxModel::init,
@@ -76,7 +73,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     ;
     m.def_submodule("ops")
-        .def("gemm_w4a4", nunchaku::ops::gemm_w4a4)
+        .def("gemm_cuda", nunchaku::ops::gemm_cuda)
         .def("gemv_awq", nunchaku::ops::gemv_awq)
     ;
...
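With the top-level awq_gemm_forward_cuda / gemv_awq bindings removed, the kernels are reached only through the extension's ops submodule. A minimal sketch of the new call path from Python (the tensor arguments are placeholders, argument order taken from W4Linear.forward below):

# Sketch only: the kernels now live under the extension's "ops" submodule.
from nunchaku._C.ops import gemm_cuda, gemv_awq  # replaces the removed top-level _C bindings

out = gemm_cuda(x, qweight, scales, scaled_zeros)  # x, qweight, scales, scaled_zeros are placeholders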
@@ -6,7 +6,7 @@ import warnings
 import torch
 import torch.nn as nn
-from nunchaku.csrc.load import _C
+from nunchaku._C.ops import gemm_cuda, gemv_awq
 from .tinychat_utils import ceil_num_groups, convert_to_tinychat_w4x16y16_linear_weight
 __all__ = ["W4Linear"]
@@ -78,7 +78,7 @@ class W4Linear(nn.Module):
     @torch.no_grad()
     def forward(self, x):
         if x.numel() / x.shape[-1] < 8:
-            out = _C.awq_gemv_forward_cuda(
+            out = gemv_awq(
                 x,
                 self.qweight,
                 self.scales,
@@ -89,7 +89,7 @@ class W4Linear(nn.Module):
                 self.group_size,
             )
         else:
-            out = _C.awq_gemm_forward_cuda(x, self.qweight, self.scales, self.scaled_zeros)
+            out = gemm_cuda(x, self.qweight, self.scales, self.scaled_zeros)
         out = out + self.bias if self.bias is not None else out
         return out
...
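The dispatch rule in W4Linear.forward is unchanged by this commit: the input is viewed as a (rows, in_features) matrix, small batches go to the GEMV kernel, and larger ones go to the GEMM kernel. A small illustration of that threshold (x is a placeholder tensor):

# Illustration of the existing dispatch rule; the threshold of 8 is taken from forward above.
rows = x.numel() // x.shape[-1]          # number of rows once leading dimensions are flattened
kernel = gemv_awq if rows < 8 else gemm_cuda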
@@ -217,7 +217,7 @@ class Tensor {
 public:
     enum ScalarType {
         INVALID_SCALAR_TYPE,
-        INT8, INT32, INT64,
+        INT8, INT16, INT32, INT64,
         FP16, FP32, BF16
     };
@@ -540,6 +540,7 @@ public:
 inline const std::map<Tensor::ScalarType, size_t> Tensor::scalarSize = {
     {INT8, 1},
+    {INT16, 2},
     {INT32, 4},
     {INT64, 8},
     {FP16, 2},
...
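The new scalarSize entry records INT16 as 2 bytes per element. A quick cross-check sketch against torch's own dtype accounting:

# INT16 corresponds to torch.int16 (at::ScalarType::Short), which is 2 bytes per element.
import torch
assert torch.empty(0, dtype=torch.int16).element_size() == 2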
@@ -28,6 +28,8 @@ Tensor from_torch(at::Tensor input) {
         { at::ScalarType::Float, Tensor::FP32 },
         { at::ScalarType::Half, Tensor::FP16 },
         { at::ScalarType::BFloat16, Tensor::BF16 },
+        { at::ScalarType::Short, Tensor::INT16 },
     };
     result.scalarType = mapType.at(input.scalar_type());
@@ -53,6 +55,8 @@ at::Tensor to_torch(Tensor input) {
         { Tensor::FP32, at::ScalarType::Float },
         { Tensor::FP16, at::ScalarType::Half },
         { Tensor::BF16, at::ScalarType::BFloat16 },
+        { Tensor::INT16, at::ScalarType::Short },
     };
     c10::TensorOptions opts(mapType.at(input.scalar_type()));
...
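With the Short <-> INT16 entries added to both maps, an int16 torch tensor can now cross the from_torch/to_torch boundary instead of failing the mapType.at() lookup. A hedged round-trip sketch from the Python side (which specific op receives the int16 tensor is left open):

import torch

# torch.int16 is at::ScalarType::Short on the C++ side, now mapped to Tensor::INT16.
codes = torch.randint(-32768, 32767, (16, 64), dtype=torch.int16)
# Any binding whose arguments pass through from_torch() can now accept `codes`;
# previously mapType.at(at::ScalarType::Short) would have thrown std::out_of_range.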