Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xdb4_94051
vllm
Commits
eb8e460c
Commit
eb8e460c
authored
Sep 13, 2024
by
nicodafagood
Browse files
update mygq
parent
23fdbb68
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
156 additions
and
156 deletions
+156
-156
benchmarks/benchmark_latency.py
benchmarks/benchmark_latency.py
+1
-1
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+1
-1
csrc/ops.h
csrc/ops.h
+4
-4
csrc/pybind.cpp
csrc/pybind.cpp
+2
-2
csrc/quantization/mygq/compat.cuh
csrc/quantization/mygq/compat.cuh
+2
-2
csrc/quantization/mygq/matrix_view.cuh
csrc/quantization/mygq/matrix_view.cuh
+2
-2
csrc/quantization/mygq/q_gemm.cu
csrc/quantization/mygq/q_gemm.cu
+125
-125
csrc/quantization/mygq/qdq_2.cuh
csrc/quantization/mygq/qdq_2.cuh
+2
-2
csrc/quantization/mygq/qdq_3.cuh
csrc/quantization/mygq/qdq_3.cuh
+2
-2
csrc/quantization/mygq/qdq_4.cuh
csrc/quantization/mygq/qdq_4.cuh
+3
-3
csrc/quantization/mygq/qdq_8.cuh
csrc/quantization/mygq/qdq_8.cuh
+2
-2
csrc/quantization/mygq/qdq_util.cuh
csrc/quantization/mygq/qdq_util.cuh
+2
-2
setup.py
setup.py
+1
-1
vllm/config.py
vllm/config.py
+1
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-1
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+2
-2
vllm/model_executor/layers/quantization/mygq.py
vllm/model_executor/layers/quantization/mygq.py
+3
-3
No files found.
benchmarks/benchmark_latency.py
View file @
eb8e460c
...
@@ -92,7 +92,7 @@ if __name__ == '__main__':
...
@@ -92,7 +92,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--tokenizer'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--tokenizer'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--quantization'
,
parser
.
add_argument
(
'--quantization'
,
'-q'
,
'-q'
,
choices
=
[
'awq'
,
'gptq'
,
'myq'
,
'squeezellm'
,
None
],
choices
=
[
'awq'
,
'gptq'
,
'my
g
q'
,
'squeezellm'
,
None
],
default
=
None
)
default
=
None
)
parser
.
add_argument
(
'--tensor-parallel-size'
,
'-tp'
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
'--tensor-parallel-size'
,
'-tp'
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
'--input-len'
,
type
=
int
,
default
=
32
)
parser
.
add_argument
(
'--input-len'
,
type
=
int
,
default
=
32
)
...
...
benchmarks/benchmark_throughput.py
View file @
eb8e460c
...
@@ -258,7 +258,7 @@ if __name__ == "__main__":
...
@@ -258,7 +258,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--tokenizer"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
"--tokenizer"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--quantization'
,
parser
.
add_argument
(
'--quantization'
,
'-q'
,
'-q'
,
choices
=
[
'awq'
,
'gptq'
,
'myq'
,
'squeezellm'
,
None
],
choices
=
[
'awq'
,
'gptq'
,
'my
g
q'
,
'squeezellm'
,
None
],
default
=
None
)
default
=
None
)
parser
.
add_argument
(
"--tensor-parallel-size"
,
"-tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--tensor-parallel-size"
,
"-tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--n"
,
parser
.
add_argument
(
"--n"
,
...
...
csrc/ops.h
View file @
eb8e460c
...
@@ -115,16 +115,16 @@ void gptq_shuffle(
...
@@ -115,16 +115,16 @@ void gptq_shuffle(
torch
::
Tensor
q_perm
,
torch
::
Tensor
q_perm
,
int
bit
);
int
bit
);
torch
::
Tensor
myq_gemm
(
torch
::
Tensor
my
g
q_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_myq_qzeros
,
torch
::
Tensor
b_my
g
q_qzeros
,
torch
::
Tensor
b_myq_scales
,
torch
::
Tensor
b_my
g
q_scales
,
torch
::
Tensor
b_g_idx
,
torch
::
Tensor
b_g_idx
,
bool
use_exllama
,
bool
use_exllama
,
int
bit
);
int
bit
);
void
myq_shuffle
(
void
my
g
q_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
torch
::
Tensor
q_perm
,
int
bit
);
int
bit
);
...
...
csrc/pybind.cpp
View file @
eb8e460c
...
@@ -61,8 +61,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -61,8 +61,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops
.
def
(
"gptq_gemm"
,
&
gptq_gemm
,
"Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_gemm"
,
&
gptq_gemm
,
"Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_shuffle"
,
&
gptq_shuffle
,
"Post processing for GPTQ"
);
ops
.
def
(
"gptq_shuffle"
,
&
gptq_shuffle
,
"Post processing for GPTQ"
);
ops
.
def
(
"myq_gemm"
,
&
myq_gemm
,
"Quantized GEMM for myq"
);
ops
.
def
(
"my
g
q_gemm"
,
&
my
g
q_gemm
,
"Quantized GEMM for my
g
q"
);
ops
.
def
(
"myq_shuffle"
,
&
myq_shuffle
,
"Post processing for
GPTQ
"
);
ops
.
def
(
"my
g
q_shuffle"
,
&
my
g
q_shuffle
,
"Post processing for
mygq
"
);
ops
.
def
(
"squeezellm_gemm"
,
&
squeezellm_gemm
,
"Quantized GEMM for SqueezeLLM"
);
ops
.
def
(
"squeezellm_gemm"
,
&
squeezellm_gemm
,
"Quantized GEMM for SqueezeLLM"
);
ops
.
def
(
ops
.
def
(
"moe_align_block_size"
,
"moe_align_block_size"
,
...
...
csrc/quantization/myq/compat.cuh
→
csrc/quantization/my
g
q/compat.cuh
View file @
eb8e460c
...
@@ -6,7 +6,7 @@ Copied from https://github.com/turboderp/exllamav2
...
@@ -6,7 +6,7 @@ Copied from https://github.com/turboderp/exllamav2
#define _compat_cuh
#define _compat_cuh
namespace
vllm
{
namespace
vllm
{
namespace
myq
{
namespace
my
g
q
{
// atomicAdd for half types, to support CC < 7.x
// atomicAdd for half types, to support CC < 7.x
__device__
__forceinline__
void
atomicAdd_half
(
half
*
address
,
half
val
)
__device__
__forceinline__
void
atomicAdd_half
(
half
*
address
,
half
val
)
...
@@ -59,6 +59,6 @@ __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd
...
@@ -59,6 +59,6 @@ __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd
#endif
#endif
#endif
#endif
}
// namespace myq
}
// namespace my
g
q
}
// namespace vllm
}
// namespace vllm
#endif
#endif
csrc/quantization/myq/matrix_view.cuh
→
csrc/quantization/my
g
q/matrix_view.cuh
View file @
eb8e460c
...
@@ -11,7 +11,7 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo
...
@@ -11,7 +11,7 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo
#include "qdq_util.cuh"
#include "qdq_util.cuh"
namespace
vllm
{
namespace
vllm
{
namespace
myq
{
namespace
my
g
q
{
class
MatrixView_half
class
MatrixView_half
{
{
...
@@ -269,6 +269,6 @@ public:
...
@@ -269,6 +269,6 @@ public:
}
}
};
};
}
// namespace myq
}
// namespace my
g
q
}
// namespace vllm
}
// namespace vllm
#endif
#endif
csrc/quantization/myq/q_gemm.cu
→
csrc/quantization/my
g
q/q_gemm.cu
View file @
eb8e460c
...
@@ -21,7 +21,7 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopq
...
@@ -21,7 +21,7 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopq
#include "qdq_8.cuh"
#include "qdq_8.cuh"
namespace
vllm
{
namespace
vllm
{
namespace
myq
{
namespace
my
g
q
{
#define BLOCK_KN_SIZE 128
#define BLOCK_KN_SIZE 128
#define BLOCK_M_SIZE_MAX 8
#define BLOCK_M_SIZE_MAX 8
...
@@ -181,7 +181,7 @@ __forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, co
...
@@ -181,7 +181,7 @@ __forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, co
}
}
typedef
void
(
*
fp_gemm_half_q_half_myq_kernel
)
typedef
void
(
*
fp_gemm_half_q_half_my
g
q_kernel
)
(
(
const
half
*
,
const
half
*
,
const
uint32_t
*
,
const
uint32_t
*
,
...
@@ -197,12 +197,12 @@ typedef void (*fp_gemm_half_q_half_myq_kernel)
...
@@ -197,12 +197,12 @@ typedef void (*fp_gemm_half_q_half_myq_kernel)
template
<
bool
first_block
,
int
m_count
>
template
<
bool
first_block
,
int
m_count
>
__global__
void
gemm_half_q_half_myq_4bit_kernel
__global__
void
gemm_half_q_half_my
g
q_4bit_kernel
(
(
const
half
*
__restrict__
a
,
const
half
*
__restrict__
a
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_myq_qzeros
,
const
uint32_t
*
__restrict__
b_my
g
q_qzeros
,
const
half
*
__restrict__
b_myq_scales
,
const
half
*
__restrict__
b_my
g
q_scales
,
half
*
__restrict__
c
,
half
*
__restrict__
c
,
const
int
size_m
,
const
int
size_m
,
const
int
size_n
,
const
int
size_n
,
...
@@ -213,8 +213,8 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
...
@@ -213,8 +213,8 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
{
{
MatrixView_half
a_
(
a
,
size_m
,
size_k
);
MatrixView_half
a_
(
a
,
size_m
,
size_k
);
MatrixView_half_rw
c_
(
c
,
size_m
,
size_n
);
MatrixView_half_rw
c_
(
c
,
size_m
,
size_n
);
MatrixView_q4_row
b_myq_qzeros_
(
b_myq_qzeros
,
groups
,
size_n
);
MatrixView_q4_row
b_my
g
q_qzeros_
(
b_my
g
q_qzeros
,
groups
,
size_n
);
MatrixView_half
b_myq_scales_
(
b_myq_scales
,
groups
,
size_n
);
MatrixView_half
b_my
g
q_scales_
(
b_my
g
q_scales
,
groups
,
size_n
);
int
t
=
threadIdx
.
x
;
int
t
=
threadIdx
.
x
;
...
@@ -274,8 +274,8 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
...
@@ -274,8 +274,8 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
float
scales
[
4
];
float
scales
[
4
];
half2
z1z16
[
4
][
2
];
half2
z1z16
[
4
][
2
];
half2
y1y16
[
4
][
2
];
half2
y1y16
[
4
][
2
];
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4_f
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4_f
(
scales
,
group
,
n
);
dequant_4bit_8_prep_zero
(
zeros
[
0
]
+
1
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
(
zeros
[
0
]
+
1
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
(
zeros
[
1
]
+
1
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
(
zeros
[
1
]
+
1
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
(
zeros
[
2
]
+
1
,
z1z16
[
2
],
y1y16
[
2
]);
dequant_4bit_8_prep_zero
(
zeros
[
2
]
+
1
,
z1z16
[
2
],
y1y16
[
2
]);
...
@@ -292,8 +292,8 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
...
@@ -292,8 +292,8 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
{
{
group
++
;
group
++
;
nextgroup
+=
groupsize
;
nextgroup
+=
groupsize
;
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4_f
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4_f
(
scales
,
group
,
n
);
dequant_4bit_8_prep_zero
(
zeros
[
0
]
+
1
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
(
zeros
[
0
]
+
1
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
(
zeros
[
1
]
+
1
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
(
zeros
[
1
]
+
1
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
(
zeros
[
2
]
+
1
,
z1z16
[
2
],
y1y16
[
2
]);
dequant_4bit_8_prep_zero
(
zeros
[
2
]
+
1
,
z1z16
[
2
],
y1y16
[
2
]);
...
@@ -307,10 +307,10 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
...
@@ -307,10 +307,10 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
int4
load_int4
=
*
b_ptr4
;
int4
load_int4
=
*
b_ptr4
;
half2
dq
[
4
][
4
];
half2
dq
[
4
][
4
];
dequant_4bit_8_myq
(
load_int4
.
x
,
dq
[
0
],
z1z16
[
0
],
y1y16
[
0
],
size_n
,
false
);
dequant_4bit_8_my
g
q
(
load_int4
.
x
,
dq
[
0
],
z1z16
[
0
],
y1y16
[
0
],
size_n
,
false
);
dequant_4bit_8_myq
(
load_int4
.
y
,
dq
[
1
],
z1z16
[
1
],
y1y16
[
1
],
size_n
,
false
);
dequant_4bit_8_my
g
q
(
load_int4
.
y
,
dq
[
1
],
z1z16
[
1
],
y1y16
[
1
],
size_n
,
false
);
dequant_4bit_8_myq
(
load_int4
.
z
,
dq
[
2
],
z1z16
[
2
],
y1y16
[
2
],
size_n
,
false
);
dequant_4bit_8_my
g
q
(
load_int4
.
z
,
dq
[
2
],
z1z16
[
2
],
y1y16
[
2
],
size_n
,
false
);
dequant_4bit_8_myq
(
load_int4
.
w
,
dq
[
3
],
z1z16
[
3
],
y1y16
[
3
],
size_n
,
false
);
dequant_4bit_8_my
g
q
(
load_int4
.
w
,
dq
[
3
],
z1z16
[
3
],
y1y16
[
3
],
size_n
,
false
);
#pragma unroll
#pragma unroll
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
for
(
int
m
=
0
;
m
<
m_count
;
m
++
)
...
@@ -339,12 +339,12 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
...
@@ -339,12 +339,12 @@ __global__ void gemm_half_q_half_myq_4bit_kernel
}
}
template
<
bool
first_block
,
int
m_count
>
template
<
bool
first_block
,
int
m_count
>
__global__
void
gemm_half_q_half_myq_2bit_kernel
__global__
void
gemm_half_q_half_my
g
q_2bit_kernel
(
(
const
half
*
__restrict__
a
,
const
half
*
__restrict__
a
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_myq_qzeros
,
const
uint32_t
*
__restrict__
b_my
g
q_qzeros
,
const
half
*
__restrict__
b_myq_scales
,
const
half
*
__restrict__
b_my
g
q_scales
,
half
*
__restrict__
c
,
half
*
__restrict__
c
,
const
int
size_m
,
const
int
size_m
,
const
int
size_n
,
const
int
size_n
,
...
@@ -355,8 +355,8 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
...
@@ -355,8 +355,8 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
{
{
MatrixView_half
a_
(
a
,
size_m
,
size_k
);
MatrixView_half
a_
(
a
,
size_m
,
size_k
);
MatrixView_half_rw
c_
(
c
,
size_m
,
size_n
);
MatrixView_half_rw
c_
(
c
,
size_m
,
size_n
);
MatrixView_q2_row
b_myq_qzeros_
(
b_myq_qzeros
,
groups
,
size_n
);
MatrixView_q2_row
b_my
g
q_qzeros_
(
b_my
g
q_qzeros
,
groups
,
size_n
);
MatrixView_half
b_myq_scales_
(
b_myq_scales
,
groups
,
size_n
);
MatrixView_half
b_my
g
q_scales_
(
b_my
g
q_scales
,
groups
,
size_n
);
int
t
=
threadIdx
.
x
;
int
t
=
threadIdx
.
x
;
...
@@ -414,8 +414,8 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
...
@@ -414,8 +414,8 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
// Initial group
// Initial group
int
zeros
[
4
];
int
zeros
[
4
];
half
scales
[
4
];
half
scales
[
4
];
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4
(
scales
,
group
,
n
);
// Column result
// Column result
half
block_c
[
m_count
][
4
]
=
{};
half
block_c
[
m_count
][
4
]
=
{};
...
@@ -427,8 +427,8 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
...
@@ -427,8 +427,8 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
{
{
group
++
;
group
++
;
nextgroup
+=
groupsize
;
nextgroup
+=
groupsize
;
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4
(
scales
,
group
,
n
);
}
}
#pragma unroll
#pragma unroll
...
@@ -470,12 +470,12 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
...
@@ -470,12 +470,12 @@ __global__ void gemm_half_q_half_myq_2bit_kernel
}
}
template
<
bool
first_block
,
int
m_count
>
template
<
bool
first_block
,
int
m_count
>
__global__
void
gemm_half_q_half_myq_3bit_kernel
__global__
void
gemm_half_q_half_my
g
q_3bit_kernel
(
(
const
half
*
__restrict__
a
,
const
half
*
__restrict__
a
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_myq_qzeros
,
const
uint32_t
*
__restrict__
b_my
g
q_qzeros
,
const
half
*
__restrict__
b_myq_scales
,
const
half
*
__restrict__
b_my
g
q_scales
,
half
*
__restrict__
c
,
half
*
__restrict__
c
,
const
int
size_m
,
const
int
size_m
,
const
int
size_n
,
const
int
size_n
,
...
@@ -486,8 +486,8 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
...
@@ -486,8 +486,8 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
{
{
MatrixView_half
a_
(
a
,
size_m
,
size_k
);
MatrixView_half
a_
(
a
,
size_m
,
size_k
);
MatrixView_half_rw
c_
(
c
,
size_m
,
size_n
);
MatrixView_half_rw
c_
(
c
,
size_m
,
size_n
);
MatrixView_q3_row
b_myq_qzeros_
(
b_myq_qzeros
,
groups
,
size_n
);
MatrixView_q3_row
b_my
g
q_qzeros_
(
b_my
g
q_qzeros
,
groups
,
size_n
);
MatrixView_half
b_myq_scales_
(
b_myq_scales
,
groups
,
size_n
);
MatrixView_half
b_my
g
q_scales_
(
b_my
g
q_scales
,
groups
,
size_n
);
int
t
=
threadIdx
.
x
;
int
t
=
threadIdx
.
x
;
...
@@ -545,8 +545,8 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
...
@@ -545,8 +545,8 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
// Initial group
// Initial group
int
zeros
[
4
];
int
zeros
[
4
];
half
scales
[
4
];
half
scales
[
4
];
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4
(
scales
,
group
,
n
);
// Column result
// Column result
half
block_c
[
m_count
][
4
]
=
{};
half
block_c
[
m_count
][
4
]
=
{};
...
@@ -558,8 +558,8 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
...
@@ -558,8 +558,8 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
{
{
group
++
;
group
++
;
nextgroup
+=
groupsize
;
nextgroup
+=
groupsize
;
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4
(
scales
,
group
,
n
);
}
}
#pragma unroll
#pragma unroll
...
@@ -601,12 +601,12 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
...
@@ -601,12 +601,12 @@ __global__ void gemm_half_q_half_myq_3bit_kernel
}
}
template
<
bool
first_block
,
int
m_count
>
template
<
bool
first_block
,
int
m_count
>
__global__
void
gemm_half_q_half_myq_8bit_kernel
__global__
void
gemm_half_q_half_my
g
q_8bit_kernel
(
(
const
half
*
__restrict__
a
,
const
half
*
__restrict__
a
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_myq_qzeros
,
const
uint32_t
*
__restrict__
b_my
g
q_qzeros
,
const
half
*
__restrict__
b_myq_scales
,
const
half
*
__restrict__
b_my
g
q_scales
,
half
*
__restrict__
c
,
half
*
__restrict__
c
,
const
int
size_m
,
const
int
size_m
,
const
int
size_n
,
const
int
size_n
,
...
@@ -617,8 +617,8 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
...
@@ -617,8 +617,8 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
{
{
MatrixView_half
a_
(
a
,
size_m
,
size_k
);
MatrixView_half
a_
(
a
,
size_m
,
size_k
);
MatrixView_half_rw
c_
(
c
,
size_m
,
size_n
);
MatrixView_half_rw
c_
(
c
,
size_m
,
size_n
);
MatrixView_q8_row
b_myq_qzeros_
(
b_myq_qzeros
,
groups
,
size_n
);
MatrixView_q8_row
b_my
g
q_qzeros_
(
b_my
g
q_qzeros
,
groups
,
size_n
);
MatrixView_half
b_myq_scales_
(
b_myq_scales
,
groups
,
size_n
);
MatrixView_half
b_my
g
q_scales_
(
b_my
g
q_scales
,
groups
,
size_n
);
int
t
=
threadIdx
.
x
;
int
t
=
threadIdx
.
x
;
...
@@ -676,8 +676,8 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
...
@@ -676,8 +676,8 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
// Initial group
// Initial group
int
zeros
[
4
];
int
zeros
[
4
];
half
scales
[
4
];
half
scales
[
4
];
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4
(
scales
,
group
,
n
);
// Column result
// Column result
half
block_c
[
m_count
][
4
]
=
{};
half
block_c
[
m_count
][
4
]
=
{};
...
@@ -689,8 +689,8 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
...
@@ -689,8 +689,8 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
{
{
group
++
;
group
++
;
nextgroup
+=
groupsize
;
nextgroup
+=
groupsize
;
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4
(
scales
,
group
,
n
);
}
}
#pragma unroll
#pragma unroll
...
@@ -728,15 +728,15 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
...
@@ -728,15 +728,15 @@ __global__ void gemm_half_q_half_myq_8bit_kernel
}
}
}
}
fp_gemm_half_q_half_myq_kernel
pick_gemm_half_q_half_myq_kernel
(
fp_gemm_half_q_half_my
g
q_kernel
pick_gemm_half_q_half_my
g
q_kernel
(
bool
first_block
,
const
int
m_count
,
const
int
bit
)
bool
first_block
,
const
int
m_count
,
const
int
bit
)
{
{
#define SELECT_KERNEL(M_COUNT) \
#define SELECT_KERNEL(M_COUNT) \
if (m_count == M_COUNT) { \
if (m_count == M_COUNT) { \
if (bit == 2) return gemm_half_q_half_myq_2bit_kernel<true, M_COUNT>; \
if (bit == 2) return gemm_half_q_half_my
g
q_2bit_kernel<true, M_COUNT>; \
if (bit == 3) return gemm_half_q_half_myq_3bit_kernel<true, M_COUNT>; \
if (bit == 3) return gemm_half_q_half_my
g
q_3bit_kernel<true, M_COUNT>; \
if (bit == 4) return gemm_half_q_half_myq_4bit_kernel<true, M_COUNT>; \
if (bit == 4) return gemm_half_q_half_my
g
q_4bit_kernel<true, M_COUNT>; \
if (bit == 8) return gemm_half_q_half_myq_8bit_kernel<true, M_COUNT>; \
if (bit == 8) return gemm_half_q_half_my
g
q_8bit_kernel<true, M_COUNT>; \
}
}
#if BLOCK_M_SIZE_MAX >= 1
#if BLOCK_M_SIZE_MAX >= 1
SELECT_KERNEL
(
1
);
SELECT_KERNEL
(
1
);
...
@@ -770,8 +770,8 @@ void gemm_half_q_half_cuda_part
...
@@ -770,8 +770,8 @@ void gemm_half_q_half_cuda_part
(
(
const
half
*
a
,
const
half
*
a
,
const
uint32_t
*
b_q_weight
,
const
uint32_t
*
b_q_weight
,
const
uint32_t
*
b_myq_qzeros
,
const
uint32_t
*
b_my
g
q_qzeros
,
const
half
*
b_myq_scales
,
const
half
*
b_my
g
q_scales
,
const
int
*
b_q_perm
,
const
int
*
b_q_perm
,
half
*
c
,
half
*
c
,
int
size_m
,
int
size_m
,
...
@@ -790,15 +790,15 @@ void gemm_half_q_half_cuda_part
...
@@ -790,15 +790,15 @@ void gemm_half_q_half_cuda_part
gridDim
.
y
=
DIVIDE
(
size_m
,
m_count
);
gridDim
.
y
=
DIVIDE
(
size_m
,
m_count
);
gridDim
.
z
=
DIVIDE
(
size_k
,
BLOCK_KN_SIZE
);
gridDim
.
z
=
DIVIDE
(
size_k
,
BLOCK_KN_SIZE
);
fp_gemm_half_q_half_myq_kernel
kernel
=
pick_gemm_half_q_half_myq_kernel
(
true
,
m_count
,
bit
);
fp_gemm_half_q_half_my
g
q_kernel
kernel
=
pick_gemm_half_q_half_my
g
q_kernel
(
true
,
m_count
,
bit
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
kernel
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
kernel
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
(
a
,
a
,
b_q_weight
,
b_q_weight
,
b_myq_qzeros
,
b_my
g
q_qzeros
,
b_myq_scales
,
b_my
g
q_scales
,
c
,
c
,
size_m
,
size_m
,
size_n
,
size_n
,
...
@@ -813,8 +813,8 @@ __global__ void reconstruct_exllama_8bit_kernel
...
@@ -813,8 +813,8 @@ __global__ void reconstruct_exllama_8bit_kernel
(
(
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
int
*
__restrict__
b_q_perm
,
const
int
*
__restrict__
b_q_perm
,
const
uint32_t
*
__restrict__
b_myq_qzeros
,
const
uint32_t
*
__restrict__
b_my
g
q_qzeros
,
const
half
*
__restrict__
b_myq_scales
,
const
half
*
__restrict__
b_my
g
q_scales
,
const
int
size_k
,
const
int
size_k
,
const
int
size_n
,
const
int
size_n
,
const
int
groups
,
const
int
groups
,
...
@@ -822,8 +822,8 @@ __global__ void reconstruct_exllama_8bit_kernel
...
@@ -822,8 +822,8 @@ __global__ void reconstruct_exllama_8bit_kernel
)
)
{
{
MatrixView_half_rw
b_
(
b
,
size_k
,
size_n
);
MatrixView_half_rw
b_
(
b
,
size_k
,
size_n
);
MatrixView_q8_row
b_myq_qzeros_
(
b_myq_qzeros
,
groups
,
size_n
);
MatrixView_q8_row
b_my
g
q_qzeros_
(
b_my
g
q_qzeros
,
groups
,
size_n
);
MatrixView_half
b_myq_scales_
(
b_myq_scales
,
groups
,
size_n
);
MatrixView_half
b_my
g
q_scales_
(
b_my
g
q_scales
,
groups
,
size_n
);
int
offset_k
=
BLOCK_KN_SIZE
*
blockIdx
.
y
;
int
offset_k
=
BLOCK_KN_SIZE
*
blockIdx
.
y
;
int
offset_n
=
BLOCK_KN_SIZE
*
blockIdx
.
x
*
4
;
int
offset_n
=
BLOCK_KN_SIZE
*
blockIdx
.
x
*
4
;
...
@@ -857,8 +857,8 @@ __global__ void reconstruct_exllama_8bit_kernel
...
@@ -857,8 +857,8 @@ __global__ void reconstruct_exllama_8bit_kernel
// Initial zeros/scale
// Initial zeros/scale
int
zeros
[
4
];
int
zeros
[
4
];
half2
scales
[
4
];
half2
scales
[
4
];
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4_h2
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4_h2
(
scales
,
group
,
n
);
__syncthreads
();
__syncthreads
();
...
@@ -871,8 +871,8 @@ __global__ void reconstruct_exllama_8bit_kernel
...
@@ -871,8 +871,8 @@ __global__ void reconstruct_exllama_8bit_kernel
{
{
group
++
;
group
++
;
nextgroup
+=
groupsize
;
nextgroup
+=
groupsize
;
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4_h2
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4_h2
(
scales
,
group
,
n
);
}
}
for
(
int
p
=
0
;
p
<
4
;
p
++
)
for
(
int
p
=
0
;
p
<
4
;
p
++
)
...
@@ -915,8 +915,8 @@ __global__ void reconstruct_exllama_4bit_kernel
...
@@ -915,8 +915,8 @@ __global__ void reconstruct_exllama_4bit_kernel
(
(
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
int
*
__restrict__
b_q_perm
,
const
int
*
__restrict__
b_q_perm
,
const
uint32_t
*
__restrict__
b_myq_qzeros
,
const
uint32_t
*
__restrict__
b_my
g
q_qzeros
,
const
half
*
__restrict__
b_myq_scales
,
const
half
*
__restrict__
b_my
g
q_scales
,
const
int
size_k
,
const
int
size_k
,
const
int
size_n
,
const
int
size_n
,
const
int
groups
,
const
int
groups
,
...
@@ -924,8 +924,8 @@ __global__ void reconstruct_exllama_4bit_kernel
...
@@ -924,8 +924,8 @@ __global__ void reconstruct_exllama_4bit_kernel
)
)
{
{
MatrixView_half_rw
b_
(
b
,
size_k
,
size_n
);
MatrixView_half_rw
b_
(
b
,
size_k
,
size_n
);
MatrixView_q4_row
b_myq_qzeros_
(
b_myq_qzeros
,
groups
,
size_n
);
MatrixView_q4_row
b_my
g
q_qzeros_
(
b_my
g
q_qzeros
,
groups
,
size_n
);
MatrixView_half
b_myq_scales_
(
b_myq_scales
,
groups
,
size_n
);
MatrixView_half
b_my
g
q_scales_
(
b_my
g
q_scales
,
groups
,
size_n
);
int
offset_k
=
BLOCK_KN_SIZE
*
blockIdx
.
y
;
int
offset_k
=
BLOCK_KN_SIZE
*
blockIdx
.
y
;
int
offset_n
=
BLOCK_KN_SIZE
*
blockIdx
.
x
*
4
;
int
offset_n
=
BLOCK_KN_SIZE
*
blockIdx
.
x
*
4
;
...
@@ -961,8 +961,8 @@ __global__ void reconstruct_exllama_4bit_kernel
...
@@ -961,8 +961,8 @@ __global__ void reconstruct_exllama_4bit_kernel
half2
scales
[
4
];
half2
scales
[
4
];
half2
z1z16
[
4
][
2
];
half2
z1z16
[
4
][
2
];
half2
y1y16
[
4
][
2
];
half2
y1y16
[
4
][
2
];
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4_h2
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4_h2
(
scales
,
group
,
n
);
dequant_4bit_8_prep_zero
(
zeros
[
0
]
+
1
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
(
zeros
[
0
]
+
1
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
(
zeros
[
1
]
+
1
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
(
zeros
[
1
]
+
1
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
(
zeros
[
2
]
+
1
,
z1z16
[
2
],
y1y16
[
2
]);
dequant_4bit_8_prep_zero
(
zeros
[
2
]
+
1
,
z1z16
[
2
],
y1y16
[
2
]);
...
@@ -979,8 +979,8 @@ __global__ void reconstruct_exllama_4bit_kernel
...
@@ -979,8 +979,8 @@ __global__ void reconstruct_exllama_4bit_kernel
{
{
group
++
;
group
++
;
nextgroup
+=
groupsize
;
nextgroup
+=
groupsize
;
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4_h2
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4_h2
(
scales
,
group
,
n
);
dequant_4bit_8_prep_zero
(
zeros
[
0
]
+
1
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
(
zeros
[
0
]
+
1
,
z1z16
[
0
],
y1y16
[
0
]);
dequant_4bit_8_prep_zero
(
zeros
[
1
]
+
1
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
(
zeros
[
1
]
+
1
,
z1z16
[
1
],
y1y16
[
1
]);
dequant_4bit_8_prep_zero
(
zeros
[
2
]
+
1
,
z1z16
[
2
],
y1y16
[
2
]);
dequant_4bit_8_prep_zero
(
zeros
[
2
]
+
1
,
z1z16
[
2
],
y1y16
[
2
]);
...
@@ -993,10 +993,10 @@ __global__ void reconstruct_exllama_4bit_kernel
...
@@ -993,10 +993,10 @@ __global__ void reconstruct_exllama_4bit_kernel
const
int4
*
b_ptr4
=
(
int4
*
)
b_ptr
;
const
int4
*
b_ptr4
=
(
int4
*
)
b_ptr
;
int4
load_int4
=
*
b_ptr4
;
int4
load_int4
=
*
b_ptr4
;
dequant_4bit_8_myq
(
load_int4
.
x
,
dq
[
0
],
z1z16
[
0
],
y1y16
[
0
],
size_n
,
false
);
dequant_4bit_8_my
g
q
(
load_int4
.
x
,
dq
[
0
],
z1z16
[
0
],
y1y16
[
0
],
size_n
,
false
);
dequant_4bit_8_myq
(
load_int4
.
y
,
dq
[
1
],
z1z16
[
1
],
y1y16
[
1
],
size_n
,
false
);
dequant_4bit_8_my
g
q
(
load_int4
.
y
,
dq
[
1
],
z1z16
[
1
],
y1y16
[
1
],
size_n
,
false
);
dequant_4bit_8_myq
(
load_int4
.
z
,
dq
[
2
],
z1z16
[
2
],
y1y16
[
2
],
size_n
,
false
);
dequant_4bit_8_my
g
q
(
load_int4
.
z
,
dq
[
2
],
z1z16
[
2
],
y1y16
[
2
],
size_n
,
false
);
dequant_4bit_8_myq
(
load_int4
.
w
,
dq
[
3
],
z1z16
[
3
],
y1y16
[
3
],
size_n
,
false
);
dequant_4bit_8_my
g
q
(
load_int4
.
w
,
dq
[
3
],
z1z16
[
3
],
y1y16
[
3
],
size_n
,
false
);
b_ptr
+=
size_n
;
b_ptr
+=
size_n
;
//half* dqh = (half*)dq;
//half* dqh = (half*)dq;
...
@@ -1027,8 +1027,8 @@ __global__ void reconstruct_exllama_3bit_kernel
...
@@ -1027,8 +1027,8 @@ __global__ void reconstruct_exllama_3bit_kernel
(
(
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
int
*
__restrict__
b_q_perm
,
const
int
*
__restrict__
b_q_perm
,
const
uint32_t
*
__restrict__
b_myq_qzeros
,
const
uint32_t
*
__restrict__
b_my
g
q_qzeros
,
const
half
*
__restrict__
b_myq_scales
,
const
half
*
__restrict__
b_my
g
q_scales
,
const
int
size_k
,
const
int
size_k
,
const
int
size_n
,
const
int
size_n
,
const
int
groups
,
const
int
groups
,
...
@@ -1036,8 +1036,8 @@ __global__ void reconstruct_exllama_3bit_kernel
...
@@ -1036,8 +1036,8 @@ __global__ void reconstruct_exllama_3bit_kernel
)
)
{
{
MatrixView_half_rw
b_
(
b
,
size_k
,
size_n
);
MatrixView_half_rw
b_
(
b
,
size_k
,
size_n
);
MatrixView_q3_row
b_myq_qzeros_
(
b_myq_qzeros
,
groups
,
size_n
);
MatrixView_q3_row
b_my
g
q_qzeros_
(
b_my
g
q_qzeros
,
groups
,
size_n
);
MatrixView_half
b_myq_scales_
(
b_myq_scales
,
groups
,
size_n
);
MatrixView_half
b_my
g
q_scales_
(
b_my
g
q_scales
,
groups
,
size_n
);
int
offset_k
=
BLOCK_KN_SIZE
*
blockIdx
.
y
;
int
offset_k
=
BLOCK_KN_SIZE
*
blockIdx
.
y
;
int
offset_n
=
BLOCK_KN_SIZE
*
blockIdx
.
x
*
4
;
int
offset_n
=
BLOCK_KN_SIZE
*
blockIdx
.
x
*
4
;
...
@@ -1071,8 +1071,8 @@ __global__ void reconstruct_exllama_3bit_kernel
...
@@ -1071,8 +1071,8 @@ __global__ void reconstruct_exllama_3bit_kernel
// Initial zeros/scale
// Initial zeros/scale
int
zeros
[
4
];
int
zeros
[
4
];
half2
scales
[
4
];
half2
scales
[
4
];
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4_h2
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4_h2
(
scales
,
group
,
n
);
__syncthreads
();
__syncthreads
();
...
@@ -1085,8 +1085,8 @@ __global__ void reconstruct_exllama_3bit_kernel
...
@@ -1085,8 +1085,8 @@ __global__ void reconstruct_exllama_3bit_kernel
{
{
group
++
;
group
++
;
nextgroup
+=
groupsize
;
nextgroup
+=
groupsize
;
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4_h2
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4_h2
(
scales
,
group
,
n
);
}
}
for
(
int
p
=
0
;
p
<
1
;
p
++
)
for
(
int
p
=
0
;
p
<
1
;
p
++
)
...
@@ -1129,8 +1129,8 @@ __global__ void reconstruct_exllama_2bit_kernel
...
@@ -1129,8 +1129,8 @@ __global__ void reconstruct_exllama_2bit_kernel
(
(
const
uint32_t
*
__restrict__
b_q_weight
,
const
uint32_t
*
__restrict__
b_q_weight
,
const
int
*
__restrict__
b_q_perm
,
const
int
*
__restrict__
b_q_perm
,
const
uint32_t
*
__restrict__
b_myq_qzeros
,
const
uint32_t
*
__restrict__
b_my
g
q_qzeros
,
const
half
*
__restrict__
b_myq_scales
,
const
half
*
__restrict__
b_my
g
q_scales
,
const
int
size_k
,
const
int
size_k
,
const
int
size_n
,
const
int
size_n
,
const
int
groups
,
const
int
groups
,
...
@@ -1138,8 +1138,8 @@ __global__ void reconstruct_exllama_2bit_kernel
...
@@ -1138,8 +1138,8 @@ __global__ void reconstruct_exllama_2bit_kernel
)
)
{
{
MatrixView_half_rw
b_
(
b
,
size_k
,
size_n
);
MatrixView_half_rw
b_
(
b
,
size_k
,
size_n
);
MatrixView_q2_row
b_myq_qzeros_
(
b_myq_qzeros
,
groups
,
size_n
);
MatrixView_q2_row
b_my
g
q_qzeros_
(
b_my
g
q_qzeros
,
groups
,
size_n
);
MatrixView_half
b_myq_scales_
(
b_myq_scales
,
groups
,
size_n
);
MatrixView_half
b_my
g
q_scales_
(
b_my
g
q_scales
,
groups
,
size_n
);
int
offset_k
=
BLOCK_KN_SIZE
*
blockIdx
.
y
;
int
offset_k
=
BLOCK_KN_SIZE
*
blockIdx
.
y
;
int
offset_n
=
BLOCK_KN_SIZE
*
blockIdx
.
x
*
4
;
int
offset_n
=
BLOCK_KN_SIZE
*
blockIdx
.
x
*
4
;
...
@@ -1173,8 +1173,8 @@ __global__ void reconstruct_exllama_2bit_kernel
...
@@ -1173,8 +1173,8 @@ __global__ void reconstruct_exllama_2bit_kernel
// Initial zeros/scale
// Initial zeros/scale
int
zeros
[
4
];
int
zeros
[
4
];
half2
scales
[
4
];
half2
scales
[
4
];
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4_h2
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4_h2
(
scales
,
group
,
n
);
__syncthreads
();
__syncthreads
();
...
@@ -1187,8 +1187,8 @@ __global__ void reconstruct_exllama_2bit_kernel
...
@@ -1187,8 +1187,8 @@ __global__ void reconstruct_exllama_2bit_kernel
{
{
group
++
;
group
++
;
nextgroup
+=
groupsize
;
nextgroup
+=
groupsize
;
b_myq_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_my
g
q_qzeros_
.
item4
(
zeros
,
group
,
n
);
b_myq_scales_
.
item4_h2
(
scales
,
group
,
n
);
b_my
g
q_scales_
.
item4_h2
(
scales
,
group
,
n
);
}
}
for
(
int
p
=
0
;
p
<
2
;
p
++
)
for
(
int
p
=
0
;
p
<
2
;
p
++
)
...
@@ -1230,8 +1230,8 @@ __global__ void reconstruct_exllama_2bit_kernel
...
@@ -1230,8 +1230,8 @@ __global__ void reconstruct_exllama_2bit_kernel
void
reconstruct_exllama
void
reconstruct_exllama
(
(
const
uint32_t
*
b_q_weight
,
const
uint32_t
*
b_q_weight
,
const
uint32_t
*
b_myq_qzeros
,
const
uint32_t
*
b_my
g
q_qzeros
,
const
half
*
b_myq_scales
,
const
half
*
b_my
g
q_scales
,
const
int
*
b_q_perm
,
const
int
*
b_q_perm
,
half
*
out
,
half
*
out
,
int
height
,
int
height
,
...
@@ -1260,8 +1260,8 @@ void reconstruct_exllama
...
@@ -1260,8 +1260,8 @@ void reconstruct_exllama
(
(
b_q_weight
,
b_q_weight
,
b_q_perm
,
b_q_perm
,
b_myq_qzeros
,
b_my
g
q_qzeros
,
b_myq_scales
,
b_my
g
q_scales
,
height
,
height
,
width
,
width
,
groups
,
groups
,
...
@@ -1461,8 +1461,8 @@ void gemm_half_q_half_alt
...
@@ -1461,8 +1461,8 @@ void gemm_half_q_half_alt
(
(
const
half
*
a
,
const
half
*
a
,
const
uint32_t
*
b_q_weight
,
const
uint32_t
*
b_q_weight
,
const
uint32_t
*
b_myq_qzeros
,
const
uint32_t
*
b_my
g
q_qzeros
,
const
half
*
b_myq_scales
,
const
half
*
b_my
g
q_scales
,
const
int
*
b_g_idx
,
const
int
*
b_g_idx
,
half
*
c
,
half
*
c
,
int
size_m
,
int
size_m
,
...
@@ -1490,8 +1490,8 @@ void gemm_half_q_half_alt
...
@@ -1490,8 +1490,8 @@ void gemm_half_q_half_alt
(
const
half2
*
)
a
,
(
const
half2
*
)
a
,
b_q_weight
,
b_q_weight
,
c
,
c
,
b_myq_scales
,
b_my
g
q_scales
,
b_myq_qzeros
,
b_my
g
q_qzeros
,
b_g_idx
,
b_g_idx
,
size_m
,
size_m
,
size_k
/
32
*
bit
,
size_k
/
32
*
bit
,
...
@@ -1500,7 +1500,7 @@ void gemm_half_q_half_alt
...
@@ -1500,7 +1500,7 @@ void gemm_half_q_half_alt
}
}
template
<
class
T
,
int
bit
>
template
<
class
T
,
int
bit
>
__global__
void
reconstruct_myq_kernel
__global__
void
reconstruct_my
g
q_kernel
(
(
const
uint32_t
*
__restrict__
w
,
const
uint32_t
*
__restrict__
w
,
const
half
*
__restrict__
w_scales
,
const
half
*
__restrict__
w_scales
,
...
@@ -1538,7 +1538,7 @@ __global__ void reconstruct_myq_kernel
...
@@ -1538,7 +1538,7 @@ __global__ void reconstruct_myq_kernel
}
}
}
}
__global__
void
reconstruct_myq_3bit_kernel
__global__
void
reconstruct_my
g
q_3bit_kernel
(
(
const
uint32_t
*
__restrict__
w
,
const
uint32_t
*
__restrict__
w
,
const
half
*
__restrict__
w_scales
,
const
half
*
__restrict__
w_scales
,
...
@@ -1589,11 +1589,11 @@ __global__ void reconstruct_myq_3bit_kernel
...
@@ -1589,11 +1589,11 @@ __global__ void reconstruct_myq_3bit_kernel
}
}
}
}
void
reconstruct_myq
void
reconstruct_my
g
q
(
(
const
uint32_t
*
b_q_weight
,
const
uint32_t
*
b_q_weight
,
const
uint32_t
*
b_myq_qzeros
,
const
uint32_t
*
b_my
g
q_qzeros
,
const
half
*
b_myq_scales
,
const
half
*
b_my
g
q_scales
,
const
int
*
b_g_idx
,
const
int
*
b_g_idx
,
half
*
out
,
half
*
out
,
int
height
,
int
height
,
...
@@ -1608,13 +1608,13 @@ void reconstruct_myq
...
@@ -1608,13 +1608,13 @@ void reconstruct_myq
gridDim
.
y
=
DIVIDE
(
height
,
32
/
bit
);
gridDim
.
y
=
DIVIDE
(
height
,
32
/
bit
);
gridDim
.
x
=
DIVIDE
(
width
,
BLOCK_KN_SIZE
);
gridDim
.
x
=
DIVIDE
(
width
,
BLOCK_KN_SIZE
);
auto
kernel
=
reconstruct_myq_kernel
<
MatrixView_q4_row
,
4
>
;
auto
kernel
=
reconstruct_my
g
q_kernel
<
MatrixView_q4_row
,
4
>
;
if
(
bit
==
2
)
{
if
(
bit
==
2
)
{
kernel
=
reconstruct_myq_kernel
<
MatrixView_q2_row
,
2
>
;
kernel
=
reconstruct_my
g
q_kernel
<
MatrixView_q2_row
,
2
>
;
}
else
if
(
bit
==
8
)
{
}
else
if
(
bit
==
8
)
{
kernel
=
reconstruct_myq_kernel
<
MatrixView_q8_row
,
8
>
;
kernel
=
reconstruct_my
g
q_kernel
<
MatrixView_q8_row
,
8
>
;
}
else
if
(
bit
==
3
)
{
}
else
if
(
bit
==
3
)
{
kernel
=
reconstruct_myq_3bit_kernel
;
kernel
=
reconstruct_my
g
q_3bit_kernel
;
gridDim
.
y
=
DIVIDE
(
height
,
32
);
gridDim
.
y
=
DIVIDE
(
height
,
32
);
}
}
...
@@ -1622,8 +1622,8 @@ void reconstruct_myq
...
@@ -1622,8 +1622,8 @@ void reconstruct_myq
kernel
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
kernel
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
(
b_q_weight
,
b_q_weight
,
b_myq_scales
,
b_my
g
q_scales
,
b_myq_qzeros
,
b_my
g
q_qzeros
,
b_g_idx
,
b_g_idx
,
height
,
height
,
width
,
width
,
...
@@ -1638,8 +1638,8 @@ void gemm_half_q_half_cuda
...
@@ -1638,8 +1638,8 @@ void gemm_half_q_half_cuda
cublasHandle_t
cublas_handle
,
cublasHandle_t
cublas_handle
,
const
half
*
a
,
const
half
*
a
,
const
uint32_t
*
b_q_weight
,
const
uint32_t
*
b_q_weight
,
const
uint32_t
*
b_myq_qzeros
,
const
uint32_t
*
b_my
g
q_qzeros
,
const
half
*
b_myq_scales
,
const
half
*
b_my
g
q_scales
,
const
int
*
b_g_idx
,
const
int
*
b_g_idx
,
half
*
c
,
half
*
c
,
half
*
temp_dq
,
half
*
temp_dq
,
...
@@ -1661,12 +1661,12 @@ void gemm_half_q_half_cuda
...
@@ -1661,12 +1661,12 @@ void gemm_half_q_half_cuda
if
(
use_reconstruct
)
{
if
(
use_reconstruct
)
{
// Reconstruct FP16 matrix, then cuBLAS
// Reconstruct FP16 matrix, then cuBLAS
if
(
use_exllama
)
{
if
(
use_exllama
)
{
reconstruct_exllama
(
b_q_weight
,
b_myq_qzeros
,
b_myq_scales
,
b_g_idx
,
temp_dq
,
reconstruct_exllama
(
b_q_weight
,
b_my
g
q_qzeros
,
b_my
g
q_scales
,
b_g_idx
,
temp_dq
,
size_k
,
size_n
,
groups
,
bit
);
size_k
,
size_n
,
groups
,
bit
);
}
}
else
else
{
{
reconstruct_myq
(
b_q_weight
,
b_myq_qzeros
,
b_myq_scales
,
b_g_idx
,
reconstruct_my
g
q
(
b_q_weight
,
b_my
g
q_qzeros
,
b_my
g
q_scales
,
b_g_idx
,
temp_dq
,
size_k
,
size_n
,
groups
,
bit
);
temp_dq
,
size_k
,
size_n
,
groups
,
bit
);
}
}
...
@@ -1689,22 +1689,22 @@ void gemm_half_q_half_cuda
...
@@ -1689,22 +1689,22 @@ void gemm_half_q_half_cuda
if
(
max_chunks
)
if
(
max_chunks
)
{
{
gemm_half_q_half_cuda_part
(
a
,
b_q_weight
,
b_myq_qzeros
,
b_myq_scales
,
b_g_idx
,
gemm_half_q_half_cuda_part
(
a
,
b_q_weight
,
b_my
g
q_qzeros
,
b_my
g
q_scales
,
b_g_idx
,
c
,
last_chunk
,
size_n
,
size_k
,
BLOCK_M_SIZE_MAX
,
c
,
last_chunk
,
size_n
,
size_k
,
BLOCK_M_SIZE_MAX
,
groups
,
bit
);
groups
,
bit
);
}
}
if
(
last_chunk_size
)
if
(
last_chunk_size
)
{
{
gemm_half_q_half_cuda_part
(
a
+
last_chunk
*
size_k
,
b_q_weight
,
b_myq_qzeros
,
gemm_half_q_half_cuda_part
(
a
+
last_chunk
*
size_k
,
b_q_weight
,
b_my
g
q_qzeros
,
b_myq_scales
,
b_g_idx
,
c
+
last_chunk
*
size_n
,
b_my
g
q_scales
,
b_g_idx
,
c
+
last_chunk
*
size_n
,
last_chunk_size
,
size_n
,
size_k
,
last_chunk_size
,
last_chunk_size
,
size_n
,
size_k
,
last_chunk_size
,
groups
,
bit
);
groups
,
bit
);
}
}
}
}
else
else
{
{
gemm_half_q_half_alt
(
a
,
b_q_weight
,
b_myq_qzeros
,
b_myq_scales
,
b_g_idx
,
gemm_half_q_half_alt
(
a
,
b_q_weight
,
b_my
g
q_qzeros
,
b_my
g
q_scales
,
b_g_idx
,
c
,
size_m
,
size_n
,
size_k
,
bit
);
c
,
size_m
,
size_n
,
size_k
,
bit
);
}
}
}
}
...
@@ -2020,15 +2020,15 @@ void shuffle_exllama_weight
...
@@ -2020,15 +2020,15 @@ void shuffle_exllama_weight
shuffle_kernel
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
q_weight
,
height
,
width
);
shuffle_kernel
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
q_weight
,
height
,
width
);
}
}
}
// namespace myq
}
// namespace my
g
q
}
// namespace vllm
}
// namespace vllm
torch
::
Tensor
myq_gemm
torch
::
Tensor
my
g
q_gemm
(
(
torch
::
Tensor
a
,
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_myq_qzeros
,
torch
::
Tensor
b_my
g
q_qzeros
,
torch
::
Tensor
b_myq_scales
,
torch
::
Tensor
b_my
g
q_scales
,
torch
::
Tensor
b_g_idx
,
torch
::
Tensor
b_g_idx
,
bool
use_exllama
,
bool
use_exllama
,
int
bit
int
bit
...
@@ -2039,27 +2039,27 @@ torch::Tensor myq_gemm
...
@@ -2039,27 +2039,27 @@ torch::Tensor myq_gemm
at
::
Tensor
c
=
torch
::
empty
({
a
.
size
(
0
),
b_q_weight
.
size
(
1
)},
options
);
at
::
Tensor
c
=
torch
::
empty
({
a
.
size
(
0
),
b_q_weight
.
size
(
1
)},
options
);
at
::
Tensor
temp_dq
=
torch
::
empty
({
b_q_weight
.
size
(
0
)
*
32
/
bit
,
b_q_weight
.
size
(
1
)},
options
);
at
::
Tensor
temp_dq
=
torch
::
empty
({
b_q_weight
.
size
(
0
)
*
32
/
bit
,
b_q_weight
.
size
(
1
)},
options
);
vllm
::
myq
::
gemm_half_q_half_cuda
vllm
::
my
g
q
::
gemm_half_q_half_cuda
(
(
at
::
cuda
::
getCurrentCUDABlasHandle
(),
at
::
cuda
::
getCurrentCUDABlasHandle
(),
(
const
half
*
)
a
.
data_ptr
(),
(
const
half
*
)
a
.
data_ptr
(),
(
const
uint32_t
*
)
b_q_weight
.
data_ptr
(),
(
const
uint32_t
*
)
b_q_weight
.
data_ptr
(),
(
const
uint32_t
*
)
b_myq_qzeros
.
data_ptr
(),
(
const
uint32_t
*
)
b_my
g
q_qzeros
.
data_ptr
(),
(
const
half
*
)
b_myq_scales
.
data_ptr
(),
(
const
half
*
)
b_my
g
q_scales
.
data_ptr
(),
b_g_idx
.
device
().
is_meta
()
?
NULL
:
(
const
int
*
)
b_g_idx
.
data_ptr
(),
b_g_idx
.
device
().
is_meta
()
?
NULL
:
(
const
int
*
)
b_g_idx
.
data_ptr
(),
(
half
*
)
c
.
data_ptr
(),
(
half
*
)
c
.
data_ptr
(),
(
half
*
)
temp_dq
.
data_ptr
(),
(
half
*
)
temp_dq
.
data_ptr
(),
c
.
size
(
0
),
// m
c
.
size
(
0
),
// m
c
.
size
(
1
),
// n
c
.
size
(
1
),
// n
a
.
size
(
1
),
// k
a
.
size
(
1
),
// k
b_myq_qzeros
.
size
(
0
),
// group number
b_my
g
q_qzeros
.
size
(
0
),
// group number
use_exllama
,
use_exllama
,
bit
bit
);
);
return
c
;
return
c
;
}
}
void
myq_shuffle
void
my
g
q_shuffle
(
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
torch
::
Tensor
q_perm
,
...
@@ -2067,7 +2067,7 @@ void myq_shuffle
...
@@ -2067,7 +2067,7 @@ void myq_shuffle
)
)
{
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
q_weight
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
q_weight
));
vllm
::
myq
::
shuffle_exllama_weight
(
vllm
::
my
g
q
::
shuffle_exllama_weight
(
(
uint32_t
*
)
q_weight
.
data_ptr
(),
(
uint32_t
*
)
q_weight
.
data_ptr
(),
q_perm
.
device
().
is_meta
()
?
NULL
:
(
int
*
)
q_perm
.
data_ptr
(),
q_perm
.
device
().
is_meta
()
?
NULL
:
(
int
*
)
q_perm
.
data_ptr
(),
q_weight
.
size
(
0
)
*
32
/
bit
,
q_weight
.
size
(
0
)
*
32
/
bit
,
...
...
csrc/quantization/myq/qdq_2.cuh
→
csrc/quantization/my
g
q/qdq_2.cuh
View file @
eb8e460c
...
@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
...
@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
#include "qdq_util.cuh"
#include "qdq_util.cuh"
namespace
vllm
{
namespace
vllm
{
namespace
myq
{
namespace
my
g
q
{
// Permutation:
// Permutation:
//
//
...
@@ -81,7 +81,7 @@ __forceinline__ __device__ void dequant_2bit_16
...
@@ -81,7 +81,7 @@ __forceinline__ __device__ void dequant_2bit_16
dq
[
7
]
=
__hfma2
(
q7
.
as_half2
,
y64
,
z64
);
dq
[
7
]
=
__hfma2
(
q7
.
as_half2
,
y64
,
z64
);
}
}
}
// namespace myq
}
// namespace my
g
q
}
// namespace vllm
}
// namespace vllm
#endif
#endif
csrc/quantization/myq/qdq_3.cuh
→
csrc/quantization/my
g
q/qdq_3.cuh
View file @
eb8e460c
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
#include "qdq_util.cuh"
#include "qdq_util.cuh"
namespace
vllm
{
namespace
vllm
{
namespace
myq
{
namespace
my
g
q
{
// Permutation:
// Permutation:
//
//
// v9997775 55333111 u8886664 44222000 (u, v lsb)
// v9997775 55333111 u8886664 44222000 (u, v lsb)
...
@@ -135,7 +135,7 @@ __forceinline__ __device__ void dequant_3bit_32
...
@@ -135,7 +135,7 @@ __forceinline__ __device__ void dequant_3bit_32
dq
[
15
]
=
__hadd2
(
q15
.
as_half2
,
z1
);
dq
[
15
]
=
__hadd2
(
q15
.
as_half2
,
z1
);
}
}
}
// namespace myq
}
// namespace my
g
q
}
// namespace vllm
}
// namespace vllm
#endif
#endif
csrc/quantization/myq/qdq_4.cuh
→
csrc/quantization/my
g
q/qdq_4.cuh
View file @
eb8e460c
...
@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
...
@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
#include "qdq_util.cuh"
#include "qdq_util.cuh"
namespace
vllm
{
namespace
vllm
{
namespace
myq
{
namespace
my
g
q
{
// Permutation:
// Permutation:
//
//
// 77775555 33331111 66664444 22220000
// 77775555 33331111 66664444 22220000
...
@@ -107,7 +107,7 @@ __forceinline__ __device__ void dequant_4bit_8_prep_zero
...
@@ -107,7 +107,7 @@ __forceinline__ __device__ void dequant_4bit_8_prep_zero
}
}
__forceinline__
__device__
void
dequant_4bit_8_myq
__forceinline__
__device__
void
dequant_4bit_8_my
g
q
(
(
const
uint32_t
q_0
,
const
uint32_t
q_0
,
half2
(
&
dq
)[
4
],
half2
(
&
dq
)[
4
],
...
@@ -141,7 +141,7 @@ __forceinline__ __device__ void dequant_4bit_8_myq
...
@@ -141,7 +141,7 @@ __forceinline__ __device__ void dequant_4bit_8_myq
dq
[
3
]
=
__hfma2
(
q3
.
as_half2
,
y1y16
[
1
],
z1z16
[
1
]);
// half2( q[6] - z, q[7] - z )
dq
[
3
]
=
__hfma2
(
q3
.
as_half2
,
y1y16
[
1
],
z1z16
[
1
]);
// half2( q[6] - z, q[7] - z )
}
}
}
}
}
// namespace myq
}
// namespace my
g
q
}
// namespace vllm
}
// namespace vllm
#endif
#endif
csrc/quantization/myq/qdq_8.cuh
→
csrc/quantization/my
g
q/qdq_8.cuh
View file @
eb8e460c
...
@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
...
@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
#include "qdq_util.cuh"
#include "qdq_util.cuh"
namespace
vllm
{
namespace
vllm
{
namespace
myq
{
namespace
my
g
q
{
__forceinline__
__device__
void
shuffle_8bit_4
__forceinline__
__device__
void
shuffle_8bit_4
(
(
...
@@ -34,7 +34,7 @@ __forceinline__ __device__ void dequant_8bit_8
...
@@ -34,7 +34,7 @@ __forceinline__ __device__ void dequant_8bit_8
for
(
int
i
=
0
;
i
<
4
;
i
++
)
dq
[
i
]
=
__halves2half2
(
dqh
[
i
*
2
],
dqh
[
i
*
2
+
1
]);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
dq
[
i
]
=
__halves2half2
(
dqh
[
i
*
2
],
dqh
[
i
*
2
+
1
]);
}
}
}
// namespace myq
}
// namespace my
g
q
}
// namespace vllm
}
// namespace vllm
#endif
#endif
csrc/quantization/myq/qdq_util.cuh
→
csrc/quantization/my
g
q/qdq_util.cuh
View file @
eb8e460c
...
@@ -6,7 +6,7 @@ Copied from https://github.com/turboderp/exllamav2
...
@@ -6,7 +6,7 @@ Copied from https://github.com/turboderp/exllamav2
#define _qdq_util_cuh
#define _qdq_util_cuh
namespace
vllm
{
namespace
vllm
{
namespace
myq
{
namespace
my
g
q
{
union
half2_uint32
union
half2_uint32
{
{
...
@@ -55,6 +55,6 @@ __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const i
...
@@ -55,6 +55,6 @@ __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const i
return
(
int
)(
__funnelshift_rc
(
q0
,
q1
,
shift
)
&
mask
);
return
(
int
)(
__funnelshift_rc
(
q0
,
q1
,
shift
)
&
mask
);
}
}
}
// namespace myq
}
// namespace my
g
q
}
// namespace vllm
}
// namespace vllm
#endif
#endif
setup.py
View file @
eb8e460c
...
@@ -339,7 +339,7 @@ vllm_extension_sources = [
...
@@ -339,7 +339,7 @@ vllm_extension_sources = [
"csrc/layernorm_kernels.cu"
,
"csrc/layernorm_kernels.cu"
,
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
,
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
,
"csrc/quantization/gptq/q_gemm.cu"
,
"csrc/quantization/gptq/q_gemm.cu"
,
"csrc/quantization/myq/q_gemm.cu"
,
"csrc/quantization/my
g
q/q_gemm.cu"
,
"csrc/cuda_utils_kernels.cu"
,
"csrc/cuda_utils_kernels.cu"
,
"csrc/moe_align_block_size_kernels.cu"
,
"csrc/moe_align_block_size_kernels.cu"
,
"csrc/pybind.cpp"
,
"csrc/pybind.cpp"
,
...
...
vllm/config.py
View file @
eb8e460c
...
@@ -155,7 +155,7 @@ class ModelConfig:
...
@@ -155,7 +155,7 @@ class ModelConfig:
self
.
tokenizer_mode
=
tokenizer_mode
self
.
tokenizer_mode
=
tokenizer_mode
def
_verify_quantization
(
self
)
->
None
:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
"awq"
,
"gptq"
,
"squeezellm"
,
"marlin"
,
"myq"
]
supported_quantization
=
[
"awq"
,
"gptq"
,
"squeezellm"
,
"marlin"
,
"my
g
q"
]
rocm_not_supported_quantization
=
[
"awq"
,
"marlin"
]
rocm_not_supported_quantization
=
[
"awq"
,
"marlin"
]
if
self
.
quantization
is
not
None
:
if
self
.
quantization
is
not
None
:
self
.
quantization
=
self
.
quantization
.
lower
()
self
.
quantization
=
self
.
quantization
.
lower
()
...
...
vllm/engine/arg_utils.py
View file @
eb8e460c
...
@@ -208,7 +208,7 @@ class EngineArgs:
...
@@ -208,7 +208,7 @@ class EngineArgs:
parser
.
add_argument
(
'--quantization'
,
parser
.
add_argument
(
'--quantization'
,
'-q'
,
'-q'
,
type
=
str
,
type
=
str
,
choices
=
[
'awq'
,
'gptq'
,
'squeezellm'
,
'myq'
,
None
],
choices
=
[
'awq'
,
'gptq'
,
'squeezellm'
,
'my
g
q'
,
None
],
default
=
EngineArgs
.
quantization
,
default
=
EngineArgs
.
quantization
,
help
=
'Method used to quantize the weights. If '
help
=
'Method used to quantize the weights. If '
'None, we first check the `quantization_config` '
'None, we first check the `quantization_config` '
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
eb8e460c
...
@@ -3,13 +3,13 @@ from typing import Type
...
@@ -3,13 +3,13 @@ from typing import Type
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.base_config
import
QuantizationConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.myq
import
MYQConfig
from
vllm.model_executor.layers.quantization.my
g
q
import
MYQConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
_QUANTIZATION_CONFIG_REGISTRY
=
{
_QUANTIZATION_CONFIG_REGISTRY
=
{
"awq"
:
AWQConfig
,
"awq"
:
AWQConfig
,
"myq"
:
MYQConfig
,
"my
g
q"
:
MYQConfig
,
"gptq"
:
GPTQConfig
,
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"marlin"
:
MarlinConfig
,
"marlin"
:
MarlinConfig
,
...
...
vllm/model_executor/layers/quantization/myq.py
→
vllm/model_executor/layers/quantization/my
g
q.py
View file @
eb8e460c
...
@@ -41,7 +41,7 @@ class MYQConfig(QuantizationConfig):
...
@@ -41,7 +41,7 @@ class MYQConfig(QuantizationConfig):
@
classmethod
@
classmethod
def
get_name
(
cls
)
->
str
:
def
get_name
(
cls
)
->
str
:
return
"myq"
return
"my
g
q"
@
classmethod
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
...
@@ -201,9 +201,9 @@ class MYQLinearMethod(LinearMethodBase):
...
@@ -201,9 +201,9 @@ class MYQLinearMethod(LinearMethodBase):
else
:
else
:
weights
[
"g_idx"
]
=
torch
.
empty
((
1
,
1
),
device
=
"meta"
)
weights
[
"g_idx"
]
=
torch
.
empty
((
1
,
1
),
device
=
"meta"
)
weights
[
"exllama_state"
]
=
ExllamaState
.
READY
weights
[
"exllama_state"
]
=
ExllamaState
.
READY
ops
.
myq_shuffle
(
weights
[
"qweight"
],
weights
[
"g_idx"
],
ops
.
my
g
q_shuffle
(
weights
[
"qweight"
],
weights
[
"g_idx"
],
self
.
quant_config
.
weight_bits
)
self
.
quant_config
.
weight_bits
)
output
=
ops
.
myq_gemm
(
reshaped_x
,
weights
[
"qweight"
],
output
=
ops
.
my
g
q_gemm
(
reshaped_x
,
weights
[
"qweight"
],
weights
[
"qzeros"
],
weights
[
"scales"
],
weights
[
"qzeros"
],
weights
[
"scales"
],
weights
[
"g_idx"
],
weights
[
"g_idx"
],
weights
[
"exllama_state"
]
==
ExllamaState
.
READY
,
weights
[
"exllama_state"
]
==
ExllamaState
.
READY
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment