OpenDAS / bitsandbytes · Commits

Commit 69810521, authored Mar 27, 2023 by Tim Dettmers

    Some small changes.

Parent: 6c31a5fe

Showing 5 changed files with 135 additions and 87 deletions (+135 / -87):

    bitsandbytes/nn/modules.py    +6   -2
    bitsandbytes/utils.py         +40  -0
    csrc/kernels.cu               +2   -0
    csrc/ops.cu                   +2   -0
    tests/test_functional.py      +85  -85
bitsandbytes/nn/modules.py (+6 / -2)

@@ -173,10 +173,11 @@ class FP4Params(torch.nn.Parameter):

 class LinearFP4(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None):
         super().__init__(input_features, output_features, bias)
         self.state = bnb.MatmulLtState()
         self.weight = FP4Params(self.weight.data, requires_grad=False)
+        self.compute_dtype = compute_dtype

     def init_8bit_state(self):
         pass

@@ -191,9 +192,12 @@ class LinearFP4(nn.Linear):
         if getattr(self.weight, 'quant_state', None) is None:
             print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')

         inp_dtype = x.dtype
-        x = x.to(torch.float16)
+        if self.compute_dtype is not None:
+            x = x.to(self.compute_dtype)

         bias = None if self.bias is None else self.bias.half()
         out = bnb.matmul_fp4(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
         out = out.to(inp_dtype)

         return out
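The effect of the change: `compute_dtype` replaces the hard-coded `x.to(torch.float16)` cast, so the caller chooses the precision used inside the FP4 matmul, while the output is still cast back to the input dtype. A minimal usage sketch, assuming a CUDA device and the bitsandbytes API at this commit (the shapes and dtypes are illustrative):

    import torch
    import bitsandbytes as bnb

    # A 4-bit linear layer that runs its internal matmul in float16.
    layer = bnb.nn.LinearFP4(768, 3072, bias=True, compute_dtype=torch.float16).cuda()
    # Moving the layer to CUDA quantizes the weight to FP4 and sets weight.quant_state.

    x = torch.randn(4, 768, dtype=torch.float32, device="cuda")
    out = layer(x)
    assert out.dtype == torch.float32  # forward() restores the input dtype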
bitsandbytes/utils.py (+40 / -0)

@@ -21,3 +21,43 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:
     std_out, std_err = execute_and_return_decoded_std_streams(command_string)
     return std_out, std_err

+
+def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
+    """
+    Replace linear modules with a new Linear module.
+
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        linear_replacement (`torch.nn.Module`):
+            The linear module that replaces the old one. Only expects standard arguments.
+            If other arguments need to be passed, use a lambda.
+        skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`):
+            List of module names not to convert.
+        copy_weights (`bool`):
+            Copy the weights from the old linear module to the new one.
+        post_processing_function (`str`):
+            The name of a method of the replacement linear class that is
+            called on each replaced module after processing.
+    """
+    for name, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)
+
+        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
+            old_module = model._modules[name]
+            model._modules[name] = linear_replacement(
+                module.in_features,
+                module.out_features,
+                module.bias is not None,
+            )
+            if copy_weights:
+                model._modules[name].weight = old_module.weight
+                model._modules[name].bias = old_module.bias
+
+            if post_processing_function is not None:
+                func = getattr(module, post_processing_function, None)
+                if func is not None:
+                    func(module)
+    return model
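Because `replace_linear` instantiates the replacement with only `(in_features, out_features, bias)`, extra arguments such as the new `compute_dtype` go through a lambda, exactly as the docstring suggests. A sketch of swapping every `nn.Linear` in a toy model for `LinearFP4` (the model itself is illustrative):

    import torch
    import bitsandbytes as bnb
    from bitsandbytes.utils import replace_linear

    model = torch.nn.Sequential(
        torch.nn.Linear(768, 3072),
        torch.nn.GELU(),
        torch.nn.Linear(3072, 768),
    )

    # The lambda forwards the non-standard compute_dtype argument.
    model = replace_linear(
        model,
        lambda in_f, out_f, bias: bnb.nn.LinearFP4(in_f, out_f, bias, compute_dtype=torch.float16),
    )

With the default `copy_weights=False` the replacement layers keep their freshly initialized parameters; pass `copy_weights=True` to carry over the original `weight` and `bias` tensors.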
csrc/kernels.cu (+2 / -0)

@@ -2968,6 +2968,8 @@ template __global__ void kQuantizeBlockwise<half, 128, 2, 0, 1>(float * code, ha
 template __global__ void kQuantizeBlockwise<float, 128, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 64, 2, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 64, 2, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 //template __global__ void kQuantizeBlockwise<half, 64, 1, 0, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 //template __global__ void kQuantizeBlockwise<float, 64, 1, 0, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
 template __global__ void kDequantizeBlockwise<half, 512, 64, 8, 1>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
 template __global__ void kDequantizeBlockwise<float, 512, 64, 8, 1>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
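The two added instantiations extend the FP4 quantization path (the final template parameter, `1`) to a block size of 64 for both `half` and `float` inputs; a smaller block means one `absmax` scale per 64 elements instead of 128, spending a little more memory on scales for better precision. For intuition, a pure-PyTorch sketch of the per-block absmax scaling these kernels perform (illustrative only, not the kernel's actual code; the FP4 code-mapping step is omitted):

    import torch

    def blockwise_absmax(A: torch.Tensor, blocksize: int = 64):
        # Pad the flattened tensor to a multiple of blocksize, then scale
        # each block by its absolute maximum, which the absmax array records.
        flat = A.flatten().float()
        pad = (-flat.numel()) % blocksize
        flat = torch.nn.functional.pad(flat, (0, pad))
        blocks = flat.view(-1, blocksize)
        absmax = blocks.abs().amax(dim=1, keepdim=True).clamp_min(1e-12)
        return blocks / absmax, absmax.squeeze(1)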
csrc/ops.cu (+2 / -0)

@@ -71,6 +71,8 @@ template <typename T, int STOCHASTIC, int FP4> void quantizeBlockwise(float * co
     kQuantizeBlockwise<T, 128, 2, 0, FP4><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
+  else if(blocksize == 64)
+    kQuantizeBlockwise<T, 64, 2, 0, FP4><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
   //else if(blocksize == 32)
   //  kQuantizeBlockwise<T, 32, 1, 0, FP4><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);

   CUDA_CHECK_RETURN(cudaPeekAtLastError());
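The new dispatch branch launches 32 threads per CUDA block for blocksize 64, versus 64 threads for blocksize 128: with 2 elements per thread (the `2` in the template arguments), each CUDA block covers exactly one quantization block either way. A quick sanity check of that launch arithmetic (illustrative Python; the grid-size formula is assumed to be the ceiling division computed earlier in `quantizeBlockwise`):

    ELEMENTS_PER_THREAD = 2  # the "2" in kQuantizeBlockwise<T, 64, 2, 0, FP4>

    for blocksize, threads in [(128, 64), (64, 32)]:
        assert threads * ELEMENTS_PER_THREAD == blocksize
        n = 1 << 20  # example tensor size
        num_blocks = (n + blocksize - 1) // blocksize  # grid size for the launch
        print(f"blocksize={blocksize}: {threads} threads, {num_blocks} blocks")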
tests/test_functional.py (+85 / -85)

@@ -1784,17 +1784,17 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
     print("partial matmul", time.time() - t0)

-batch_size = 1
-seqdim = 1
+batch_size = 4
+seqdim = 256
 values = []
 values.append((batch_size, seqdim, 768, 4 * 768))
-# values.append((batch_size, seqdim, 1024, 4*1024))
-# values.append((batch_size, seqdim, 1536, 4*1536))
-# values.append((batch_size, seqdim, 2048, 4*2048))
-# values.append((batch_size, seqdim, 2560, 4*2560))
-# values.append((batch_size, seqdim, 4096, 4*4096))
-# values.append((batch_size, seqdim, 5140, 4*5140))
-# values.append((batch_size, seqdim, 12288, 4*12288))
+values.append((batch_size, seqdim, 1024, 4 * 1024))
+values.append((batch_size, seqdim, 1536, 4 * 1536))
+values.append((batch_size, seqdim, 2048, 4 * 2048))
+values.append((batch_size, seqdim, 2560, 4 * 2560))
+values.append((batch_size, seqdim, 4096, 4 * 4096))
+values.append((batch_size, seqdim, 5140, 4 * 5140))
+values.append((batch_size, seqdim, 12288, 4 * 12288))
 names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]

 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
 def test_bench_matmul(batch, seq, model, hidden):
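With the larger batch and sequence length (4 x 256 instead of 1 x 1) and the seven re-enabled shapes, the benchmark now runs eight configurations whose pytest ids come straight from the format string. For example, the first and last entries evaluate to (illustrative):

    values = [(4, 256, 768, 4 * 768), (4, 256, 12288, 4 * 12288)]
    names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
    print(names)
    # ['batch_4_seq_256_model_768_hidden_3072', 'batch_4_seq_256_model_12288_hidden_49152']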
@@ -1839,90 +1839,90 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     print(f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        bnb.matmul(A, B)
-    torch.cuda.synchronize()
-    print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    bnb.matmul(A, B)
+    #torch.cuda.synchronize()
+    #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        bnb.matmul(A, B, threshold=6.0)
-    torch.cuda.synchronize()
-    print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    bnb.matmul(A, B, threshold=6.0)
+    #torch.cuda.synchronize()
+    #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
-    C32A, SA = F.transform(CA, "col32")
-    CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
-    CxB, SB = F.transform(CB, to_order=formatB)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-    torch.cuda.synchronize()
-    print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
+    #C32A, SA = F.transform(CA, "col32")
+    #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
+    #CxB, SB = F.transform(CB, to_order=formatB)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+    #torch.cuda.synchronize()
+    #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    BA, statsB = F.vectorwise_quant(B, dim=1)
-    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        A2 = A.view(-1, A.shape[-1]).contiguous()
-        CA, statsA = F.vectorwise_quant(A2, dim=1)
-        C32A, SA = F.nvidia_transform(CA, "col32")
-        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
-        F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
-    torch.cuda.synchronize()
-    print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #BA, statsB = F.vectorwise_quant(B, dim=1)
+    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    A2 = A.view(-1, A.shape[-1]).contiguous()
+    #    CA, statsA = F.vectorwise_quant(A2, dim=1)
+    #    C32A, SA = F.nvidia_transform(CA, "col32")
+    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
+    #    F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
+    #torch.cuda.synchronize()
+    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
-    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        A2 = A.view(-1, A.shape[-1]).contiguous()
-        CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
-        C32A, SA = F.nvidia_transform(CA, "col32")
-        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
-        Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
-        out = Cout * statsB * statsA * (1.0 / (127 * 127))
-    torch.cuda.synchronize()
-    print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
+    #CxB, SB = F.nvidia_transform(CB, to_order=formatB)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    A2 = A.view(-1, A.shape[-1]).contiguous()
+    #    CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
+    #    C32A, SA = F.nvidia_transform(CA, "col32")
+    #    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+    #    Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
+    #    out = Cout * statsB * statsA * (1.0 / (127 * 127))
+    #torch.cuda.synchronize()
+    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    linear8bit(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linear8bit(A)
-    torch.cuda.synchronize()
-    print(f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linear8bit(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linear8bit(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    linearMixedBit(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linearMixedBit(A)
-    torch.cuda.synchronize()
-    print(f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linearMixedBit(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linearMixedBit(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    linear8bit_train(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linear8bit_train(A)
-    torch.cuda.synchronize()
-    print(f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linear8bit_train(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linear8bit_train(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

-    linear8bit_train_thresh(A)
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        linear8bit_train(A)
-    torch.cuda.synchronize()
-    print(f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    #linear8bit_train_thresh(A)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    linear8bit_train(A)
+    #torch.cuda.synchronize()
+    #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

 def test_zeropoint():
     def quant_zp(x):
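All of these benchmark blocks, active or commented out, share one timing pattern: a warm-up call, `torch.cuda.synchronize()` before starting the clock, the timed loop, and a final synchronize before reading the clock, because CUDA kernel launches return asynchronously. A self-contained sketch of the pattern (assuming a CUDA device; `fn` stands in for any of the benchmarked calls, and the fp16 baseline here is illustrative):

    import time
    import torch

    def bench(fn, iters=100):
        fn()                      # warm-up: triggers lazy init and caching
        torch.cuda.synchronize()  # drain pending kernels before timing
        t0 = time.time()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()  # wait for all timed kernels to finish
        return time.time() - t0

    A = torch.randn(4, 256, 768, device="cuda", dtype=torch.float16)
    B = torch.randn(3072, 768, device="cuda", dtype=torch.float16)
    print(f"fp16 linear: {bench(lambda: torch.nn.functional.linear(A, B)):.4f}s")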