OpenDAS / bitsandbytes · Commits · 758c7175

Commit 758c7175 authored Aug 04, 2022 by Tim Dettmers

Merge branch 'debug' into cuda-bin-switch-and-cli

Parents: 96bc209b, ab72a129

Showing 5 changed files with 180 additions and 199 deletions (+180 −199)
Files changed:
  Makefile                              +1    −1
  bitsandbytes/autograd/_functions.py   +15   −2
  bitsandbytes/functional.py            +81   −136
  csrc/ops.cu                           +73   −53
  tests/test_autograd.py                +10   −7
Makefile (+1 −1)
bitsandbytes/autograd/_functions.py (+15 −2)
 from dataclasses import dataclass

 import torch
+import math
 import bitsandbytes as bnb
 import bitsandbytes.functional as F

...

@@ -199,6 +199,17 @@ class MatmulLtState:
 class MatMul8bitLt(torch.autograd.Function):

     @staticmethod
     def forward(ctx, A, B, out=None, state=MatmulLtState()):
+        # default to pytorch behavior if inputs are empty
+        ctx.is_empty = False
+        if math.prod(A.shape) == 0:
+            ctx.is_empty = True
+            ctx.A = A
+            ctx.B = B
+            if A.shape[-1] == B.shape[0]:
+                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=torch.float16, device=A.device)
+            else:
+                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=torch.float16, device=A.device)
+
         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
...

@@ -339,6 +350,8 @@ class MatMul8bitLt(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
+        if ctx.is_empty:
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
         req_gradA, req_gradB = ctx.req_grads
         CAt, subA = ctx.tensors
         SCAt, idx = ctx.tensor_states
...

@@ -375,7 +388,7 @@ class MatMul8bitLt(torch.autograd.Function):
                 ctx.grad_shape
             )
-        return grad_A, grad_B, None, None, None, None, None
+        return grad_A, grad_B, None, None

 matmul = MatMul8bitLt.apply

...
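The new fast path means a zero-element matmul never reaches the quantization kernels: forward returns an empty fp16 result and backward returns zero gradients. A minimal sketch of the resulting behavior, assuming a build where `bnb.matmul` dispatches to `MatMul8bitLt.apply` as in this file (shapes here are hypothetical):

    import torch
    import bitsandbytes as bnb

    # A has zero rows, so math.prod(A.shape) == 0 and the empty branch triggers
    A = torch.empty(0, 4096, dtype=torch.float16, requires_grad=True)
    B = torch.empty(4096, 1024, dtype=torch.float16, requires_grad=True)

    out = bnb.matmul(A, B)   # shape (0, 1024), no kernel launched
    out.sum().backward()     # grads are zeros_like(A) and zeros_like(B)
    print(out.shape, A.grad.shape, B.grad.shape)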
bitsandbytes/functional.py (+81 −136)
...

@@ -4,9 +4,10 @@
 # LICENSE file in the root directory of this source tree.
 import ctypes as ct
 import random
-from typing import Tuple
+import math

 import torch
+from typing import Tuple
 from torch import Tensor

 from .cextension import COMPILED_WITH_CUDA, lib
...

@@ -193,6 +194,14 @@ def get_special_format_str():
     return "col_turing"

+def is_on_gpu(tensors):
+    on_gpu = True
+    for t in tensors:
+        if t is None: continue # NULL pointers are fine
+        on_gpu &= t.device.type == 'cuda'
+    return on_gpu
+
 def get_ptr(A: Tensor) -> ct.c_void_p:
     """
     Get the ctypes pointer from a PyTorch Tensor.
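The new `is_on_gpu` helper is the cheap device guard that this commit threads in front of nearly every ctypes kernel call below. A standalone illustration of what it evaluates, assuming only PyTorch is installed:

    import torch

    def is_on_gpu(tensors):
        on_gpu = True
        for t in tensors:
            if t is None: continue  # NULL pointers are fine
            on_gpu &= t.device.type == 'cuda'
        return on_gpu

    a = torch.ones(4, device='cuda') if torch.cuda.is_available() else torch.ones(4)
    b = torch.ones(4)                    # CPU tensor
    print(is_on_gpu([a, None]))          # True when a is on CUDA; None is skipped
    print(is_on_gpu([a, b]))             # False: b lives on the CPU

Note that, as merged here, the wrappers call `is_on_gpu([...])` and discard the boolean, so the guard documents intent at each call site; a stricter variant would assert on the result before handing raw pointers to the C library.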
...

@@ -336,7 +345,7 @@ def nvidia_transform(

 def estimate_quantiles(
     A: Tensor, out: Tensor = None, offset: float = 1 / 512
 ) -> Tensor:
-    """
+    '''
     Estimates 256 equidistant quantiles on the input tensor eCDF.

     Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles

...

@@ -361,9 +370,9 @@ def estimate_quantiles(
     -------
     torch.Tensor:
         The 256 quantiles in float32 datatype.
-    """
+    '''
     if out is None:
         out = torch.zeros((256,), dtype=torch.float32, device=A.device)
+    is_on_gpu([A, out])
     if A.dtype == torch.float32:
         lib.cestimate_quantiles_fp32(
             get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
...

@@ -428,7 +437,8 @@ def quantize_blockwise(
     if out is None:
         out = torch.zeros_like(A, dtype=torch.uint8)

-    if A.device.type != "cpu":
+    if A.device.type != 'cpu':
+        is_on_gpu([code, A, absmax, out, rand])
         if rand is not None:
             assert rand.numel() >= 1024
             rand_offset = random.randint(0, 1023)
...

@@ -541,7 +551,8 @@ def dequantize_blockwise(
             f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]"
         )

-    if A.device.type != "cpu":
+    if A.device.type != 'cpu':
+        is_on_gpu([A, out])
         if out.dtype == torch.float32:
             lib.cdequantize_blockwise_fp32(
                 get_ptr(quant_state[1]),
...

@@ -610,7 +621,7 @@ def dequantize(

 def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
-    """
+    '''
     Quantizes input tensor to 8-bit.

     Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
...

@@ -629,15 +640,15 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
     -------
     torch.Tensor:
         Quantized 8-bit tensor.
-    """
+    '''
     if out is None:
         out = torch.zeros_like(A, dtype=torch.uint8)
+    is_on_gpu([A, out])
     lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
     return out


 def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
-    """
+    '''
     Dequantizes the 8-bit tensor to 32-bit.

     Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
...

@@ -656,12 +667,10 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
     -------
     torch.Tensor:
         32-bit output tensor.
-    """
+    '''
     if out is None:
         out = torch.zeros_like(A, dtype=torch.float32)
-    lib.cdequantize(
-        get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())
-    )
+    is_on_gpu([code, A, out])
+    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
     return out
...

@@ -983,6 +992,7 @@ def percentile_clipping(
         The current optimiation steps (number of past gradient norms).
     """
+    is_on_gpu([grad, gnorm_vec])
     if grad.dtype == torch.float32:
         lib.cpercentile_clipping_g32(
             get_ptr(grad),
...

@@ -1027,21 +1037,11 @@ def histogram_scatter_add_2d(
     maxdim1 = ct.c_int32(histogram.shape[0])
     n = ct.c_int32(index1.numel())
-    lib.chistogram_scatter_add_2d(
-        get_ptr(histogram),
-        get_ptr(index1),
-        get_ptr(index2),
-        get_ptr(source),
-        maxdim1,
-        n,
-    )
+    is_on_gpu([histogram, index1, index2d, source])
+    lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)


-def check_matmul(
-    A, B, out, transposed_A, transposed_B, expected_type=torch.int8
-):
+def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
     if not torch.cuda.is_initialized(): torch.cuda.init()
     if A.dtype != expected_type or B.dtype != expected_type:
         raise TypeError(
             f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}"
@@ -1213,20 +1213,9 @@ def igemm(
...
@@ -1213,20 +1213,9 @@ def igemm(
# B^T @ A^T = C^T
# B^T @ A^T = C^T
# [km, nk -> mn]
# [km, nk -> mn]
lib
.
cigemm
(
is_on_gpu
([
B
,
A
,
out
])
ptr
,
lib
.
cigemm
(
ptr
,
ct
.
c_bool
(
transposed_B
),
ct
.
c_bool
(
transposed_A
),
ct
.
c_int32
(
m
),
ct
.
c_int32
(
n
),
ct
.
c_int32
(
k
),
ct
.
c_bool
(
transposed_B
),
get_ptr
(
B
),
get_ptr
(
A
),
get_ptr
(
out
),
ct
.
c_int32
(
lda
),
ct
.
c_int32
(
ldb
),
ct
.
c_int32
(
ldc
))
ct
.
c_bool
(
transposed_A
),
ct
.
c_int32
(
m
),
ct
.
c_int32
(
n
),
ct
.
c_int32
(
k
),
get_ptr
(
B
),
get_ptr
(
A
),
get_ptr
(
out
),
ct
.
c_int32
(
lda
),
ct
.
c_int32
(
ldb
),
ct
.
c_int32
(
ldc
),
)
return
out
return
out
...

@@ -1306,24 +1295,10 @@ def batched_igemm(
     ptr = CUBLAS_Context.get_instance().get_context(A.device)

-    lib.cbatched_igemm(
-        ptr,
-        ct.c_bool(transposed_B),
-        ct.c_bool(transposed_A),
-        ct.c_int32(m),
-        ct.c_int32(n),
-        ct.c_int32(k),
-        get_ptr(B),
-        get_ptr(A),
-        get_ptr(out),
-        ct.c_int32(lda),
-        ct.c_int32(ldb),
-        ct.c_int32(ldc),
-        ct.c_long(strideA),
-        ct.c_long(strideB),
-        ct.c_long(strideC),
-        ct.c_uint32(num_batch),
-    )
+    is_on_gpu([B, A, out])
+    lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
     return out
...

@@ -1332,15 +1307,20 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
     shapeB = SB[0]
     dimsA = len(shapeA)
     dimsB = len(shapeB)
+    assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
     if dimsA == 2:
         m = shapeA[0]
     elif dimsA == 3:
         m = shapeA[0] * shapeA[1]

-    if dimsB == 2:
-        rows = n = shapeB[0]
-    elif dimsB == 3:
-        rows = n = shapeB[0] * shapeB[1]
+    rows = n = shapeB[0]
+    assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
+
+    # if the tensor is empty, return a transformed empty tensor with the right dimensions
+    if shapeA[0] == 0 and dimsA == 2:
+        return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
+    elif shapeA[1] == 0 and dimsA == 3:
+        return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)

     if dimsA == 2 and out is None:
         out, Sout = get_transform_buffer(
@@ -1390,7 +1370,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
...
@@ -1390,7 +1370,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
has_error
=
0
has_error
=
0
ptrRowScale
=
get_ptr
(
None
)
ptrRowScale
=
get_ptr
(
None
)
if
formatB
==
"col_turing"
:
is_on_gpu
([
A
,
B
,
out
])
if
formatB
==
'col_turing'
:
if
dtype
==
torch
.
int32
:
if
dtype
==
torch
.
int32
:
has_error
=
lib
.
cigemmlt_turing_32
(
has_error
=
lib
.
cigemmlt_turing_32
(
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
ptr
,
m
,
n
,
k
,
ptrA
,
ptrB
,
ptrC
,
ptrRowScale
,
lda
,
ldb
,
ldc
...
@@ -1410,7 +1391,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
...
@@ -1410,7 +1391,8 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
)
)
if
has_error
==
1
:
if
has_error
==
1
:
raise
Exception
(
"cublasLt ran into an error!"
)
print
(
f
'A:
{
shapeA
}
, B:
{
shapeB
}
, C:
{
Sout
[
0
]
}
; (lda, ldb, ldc):
{
(
lda
,
ldb
,
ldc
)
}
; (m, n, k):
{
(
m
,
n
,
k
)
}
'
)
raise
Exception
(
'cublasLt ran into an error!'
)
torch
.
cuda
.
set_device
(
prev_device
)
torch
.
cuda
.
set_device
(
prev_device
)
...

@@ -1457,16 +1439,8 @@ def mm_dequant(
     numRows = ct.c_int32(out_shape[0])
     numCols = ct.c_int32(out_shape[1])

-    lib.cdequant_mm_int32_fp16(
-        ptrA,
-        ptrRowStats,
-        ptrColStats,
-        ptrOut,
-        ptrNewRowStats,
-        ptrNewColStats,
-        numRows,
-        numCols,
-    )
+    is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats])
+    lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)
     return out
...

@@ -1507,15 +1481,8 @@ def get_colrow_absmax(
     cols = ct.c_int32(cols)

     prev_device = pre_call(A.device)
-    lib.cget_col_row_stats(
-        ptrA,
-        ptrRowStats,
-        ptrColStats,
-        ptrNnzrows,
-        ct.c_float(threshold),
-        rows,
-        cols,
-    )
+    is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
+    lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
     post_call(prev_device)

     if threshold > 0.0:
...

@@ -1642,6 +1609,7 @@ def double_quant(
         ptrOutCol = get_ptr(out_col)
         ptrOutRow = get_ptr(out_row)

+        is_on_gpu([A, col_stats, row_stats, out_col, out_row])
         if threshold > 0.0:
             nnz = nnz_row_ptr[-1].item()
             if nnz > 0:
...

@@ -1714,33 +1682,19 @@ def get_special_format_str():
     )
     assert major >= 7

-    if major == 7:
-        return "col_turing"
-    elif major == 8:
-        return "col_ampere"
-    else:
-        return "col_turing"
+    if major == 7: return 'col_turing'
+    elif major == 8: return 'col_ampere'
+    else: return 'col_turing'


-def transform(
-    A,
-    to_order,
-    from_order="row",
-    out=None,
-    transpose=False,
-    state=None,
-    ld=None,
-):
-    if state is None:
-        state = (A.shape, from_order)
-    else:
-        from_order = state[1]
-    if out is None:
-        out, new_state = get_transform_buffer(
-            state[0], A.dtype, A.device, to_order, state[1], transpose
-        )
-    else:
-        new_state = (state[0], to_order)  # (shape, order)
+def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
+    prev_device = pre_call(A.device)
+    if state is None: state = (A.shape, from_order)
+    else: from_order = state[1]
+    if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
+    else: new_state = (state[0], to_order)  # (shape, order)

     shape = state[0]
     if len(shape) == 2:
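`get_special_format_str` picks the int8 tile layout from the GPU's compute capability: major version 7 (Turing) maps to 'col_turing', 8 (Ampere) to 'col_ampere', and anything newer falls back to 'col_turing'. A standalone restatement, assuming PyTorch is available; `special_format_for` is a hypothetical helper for illustration:

    import torch

    def special_format_for(major):
        # same mapping as the diff above
        if major == 7: return 'col_turing'
        elif major == 8: return 'col_ampere'
        else: return 'col_turing'

    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        print(major, special_format_for(major))   # e.g. 8 col_ampere on an A100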
...

@@ -1752,7 +1706,8 @@ def transform(
     ptrA = get_ptr(A)
     ptrOut = get_ptr(out)
-    if to_order == "col32":
+    is_on_gpu([A, out])
+    if to_order == 'col32':
         if transpose:
             lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
         else:
...

@@ -1773,9 +1728,9 @@ def transform(
     elif from_order == "col_ampere":
         lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
     else:
         raise NotImplementedError(
-            f"Transform function not implemented: From {from_order} to {to_order}"
+            f'Transform function not implemented: From {from_order} to {to_order}'
         )
+    post_call(prev_device)

     return out, new_state
...

@@ -1810,21 +1765,8 @@ def spmm_coo(cooA, B, out=None):
     cldb = ct.c_int32(ldb)
     cldc = ct.c_int32(ldc)

-    lib.cspmm_coo(
-        ptr,
-        ptrRowidx,
-        ptrColidx,
-        ptrValues,
-        cnnz,
-        crowsA,
-        ccolsA,
-        ccolsB,
-        cldb,
-        ptrB,
-        cldc,
-        ptrC,
-        ct.c_bool(transposed_B),
-    )
+    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
+    lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))
     return out
...

@@ -1875,6 +1817,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
     # print(cooA.rowidx[:64])
     # print(cooA.colidx[:64].sort()[0])

+    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
     if B.dtype == torch.float16:
         lib.cspmm_coo_very_sparse_naive_fp16(
             ptrMaxCount,
...

@@ -2061,9 +2004,11 @@ def extract_outliers(A, SA, idx):
     ptrIdx = get_ptr(idx)
     ptrOut = get_ptr(out)

-    if formatA == "col_turing":
+    prev_device = pre_call(A.device)
+    if formatA == 'col_turing':
         lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
     elif formatA == "col_ampere":
         lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
+    post_call(prev_device)

     return out
csrc/ops.cu (+73 −53)
...

@@ -19,53 +19,59 @@ using std::endl;
 void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n)
 {
   int threads = 512;
-  int blocks = n/threads;
-  blocks = n % threads == 0 ? blocks : blocks + 1;
-  kHistogramScatterAdd2D<<<blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
+  int num_blocks = n/threads;
+  num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+  kHistogramScatterAdd2D<<<num_blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }

 template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n)
 {
-  int blocks = n/4096;
-  blocks = n % 4096 == 0 ? blocks : blocks + 1;
+  int num_blocks = n/4096;
+  num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
   CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float)));
-  kEstimateQuantiles<T><<<blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
+  kEstimateQuantiles<T><<<num_blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }

 void quantize(float *code, float *A, unsigned char *out, int n)
 {
-  int blocks = n/1024;
-  blocks = n % 1024 == 0 ? blocks : blocks + 1;
-  kQuantize<<<blocks, 1024>>>(code, A, out, n);
+  int num_blocks = n/1024;
+  num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+  kQuantize<<<num_blocks, 1024>>>(code, A, out, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }

 void dequantize(float *code, unsigned char *A, float *out, int n)
 {
-  int blocks = n/1024;
-  blocks = n % 1024 == 0 ? blocks : blocks + 1;
-  kDequantize<<<blocks, 1024>>>(code, A, out, n);
+  int num_blocks = n/1024;
+  num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+  kDequantize<<<num_blocks, 1024>>>(code, A, out, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }

 template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
 {
-  int blocks = n/4096;
-  blocks = n % 4096 == 0 ? blocks : blocks + 1;
-  kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
+  int num_blocks = n/4096;
+  num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+  kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }

 template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
 {
-  int blocks = n/blocksize;
-  blocks = n % blocksize == 0 ? blocks : blocks + 1;
+  int num_blocks = n/blocksize;
+  num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
   if(blocksize == 4096)
-    kDequantizeBlockwise<T, 4096, 1024, 4><<<blocks, 4096/4>>>(code, A, absmax, out, n);
+    kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
   else if(blocksize == 2048)
-    kDequantizeBlockwise<T, 2048, 512, 4><<<blocks, 2048/4>>>(code, A, absmax, out, n);
+    kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
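Every launcher above follows the same host-side recipe that this commit hardens: ceil-divide n by the number of elements each block consumes, then assert the grid stays within the 65,535-block bound these asserts choose, failing loudly instead of launching a grid the kernel cannot cover. The same arithmetic, sketched in Python with the renamed variable (illustration only, hypothetical sizes):

    def num_blocks_for(n, elems_per_block):
        # ceil(n / elems_per_block), written the way the host code writes it
        num_blocks = n // elems_per_block
        num_blocks = num_blocks if n % elems_per_block == 0 else num_blocks + 1
        assert num_blocks <= 65535, "CUDA ERROR: Maximum number of blocks for kernel exceeded"
        return num_blocks

    print(num_blocks_for(4096, 4096))   # 1
    print(num_blocks_for(4097, 4096))   # 2: the remainder costs one extra block
    # num_blocks_for(65536 * 4096, 4096) would now fail fast with the assert
    # rather than silently truncating or misconfiguring the launch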
...

@@ -74,18 +80,19 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
                 const float beta1, const float beta2, const float eps, const float weight_decay,
                 const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
 {
-  int blocks = n/4096;
-  blocks = n % 4096 == 0 ? blocks : blocks + 1;
+  int num_blocks = n/4096;
+  num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
   switch(OPTIMIZER)
   {
     case ADAM:
       if(max_unorm > 0.0f)
       {
         CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
-        kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+        kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
         CUDA_CHECK_RETURN(cudaPeekAtLastError());
       }
-      kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+      kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
     case MOMENTUM:

...

@@ -95,11 +102,11 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
       if(max_unorm > 0.0f)
       {
        CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
-        kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
+        kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
        CUDA_CHECK_RETURN(cudaPeekAtLastError());
       }

-      kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
+      kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
 }
...

@@ -115,8 +122,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
                 float weight_decay,
                 const float gnorm_scale, int n)
 {
-  int blocks = n/4096;
-  blocks = n % 4096 == 0 ? blocks : blocks + 1;
+  int num_blocks = n/4096;
+  num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");

   if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); }

...

@@ -125,9 +133,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
     case ADAM:
       CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
       CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));
-      kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
+      kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
-      kOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
+      kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
                                                             quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;

...

@@ -135,9 +143,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
     case RMSPROP:
     case ADAGRAD:
       CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
-      kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
+      kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
-      kOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
+      kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
                                                             quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
...

@@ -156,22 +164,24 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
                 float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
 {
-  int blocks = 0;
+  int num_blocks = 0;
   switch(OPTIMIZER)
   {
     case ADAM:
-      blocks = n/BLOCKSIZE_2STATE;
-      blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
-      kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
+      num_blocks = n/BLOCKSIZE_2STATE;
+      num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
+      assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+      kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
                                                             quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
     case MOMENTUM:
     case RMSPROP:
     case ADAGRAD:
-      blocks = n/BLOCKSIZE_1STATE;
-      blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
-      kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
+      num_blocks = n/BLOCKSIZE_1STATE;
+      num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
+      assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+      kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
                                                             quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
       CUDA_CHECK_RETURN(cudaPeekAtLastError());
       break;
...

@@ -182,10 +192,11 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g

 template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
 {
-  int blocks = n/2048;
-  blocks = n % 2048 == 0 ? blocks : blocks + 1;
+  int num_blocks = n/2048;
+  num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
   CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
-  kPercentileClipping<T, 2048, 4><<<blocks, 512>>>(g, gnorm_vec, step, n);
+  kPercentileClipping<T, 2048, 4><<<num_blocks, 512>>>(g, gnorm_vec, step, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
...

@@ -445,10 +456,9 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
   int num_blocks = numRows/subtile_rows;
   num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
   num_blocks = num_blocks*(tileCols/32);
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
   assert(threads <= tilesize);

-  //cout << num_blocks << " blocks" << endl;
   kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
   CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
...

@@ -461,7 +471,13 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
   int tile_cols = STATS_THREADS*STATS_ITEMS;
   int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
   int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
-  int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);
+  int row_tiles = (tiledRows/STATS_ROWS);
+  int col_tiles = (tiledCols/tile_cols);
+  row_tiles = row_tiles > 0 ? row_tiles : 1;
+  col_tiles = col_tiles > 0 ? col_tiles : 1;
+  int num_blocks = row_tiles * col_tiles;
+
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");

   if(nnz_threshold == 0.0)
     kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
...

@@ -479,12 +495,14 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col
   int tile_rows = 16;
   int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
   int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
-  int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
-  //cout << cols << " " << tiledCols << " " << tiledRows << endl;
-  //cout << "num blocks " << num_blocks << endl;
-  //cout << A << " " << out_col_normed << endl;
+  int row_tiles = (tiledRows/tile_rows);
+  int col_tiles = (tiledCols/tile_cols);
+  row_tiles = row_tiles > 0 ? row_tiles : 1;
+  col_tiles = col_tiles > 0 ? col_tiles : 1;
+  int num_blocks = row_tiles * col_tiles;
+
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");

   if(threshold > 0.0f)
     kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
   else
...

@@ -502,7 +520,13 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
   int tile_rows = 32;
   int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
   int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
-  int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
+  int row_tiles = (tiledRows/tile_rows);
+  int col_tiles = (tiledCols/tile_cols);
+  row_tiles = row_tiles > 0 ? row_tiles : 1;
+  col_tiles = col_tiles > 0 ? col_tiles : 1;
+  int num_blocks = row_tiles * col_tiles;
+
+  assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");

   int outCols = fill_up_to_nearest_multiple(cols, 32);
   int outRows = fill_up_to_nearest_multiple(rows, 32);
   if(FORMAT == COL_TURING)
@@ -528,10 +552,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
...
@@ -528,10 +552,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
}
}
}
}
//cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl;
//cout << "num blocks " << num_blocks << endl;
//cout << A << " " << out_col_normed << endl;
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
TRANSPOSE
,
FORMAT
><<<
num_blocks
,
threads
>>>
(
A
,
out
,
rows
,
cols
,
tiledCols
,
outRows
,
outCols
);
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
TRANSPOSE
,
FORMAT
><<<
num_blocks
,
threads
>>>
(
A
,
out
,
rows
,
cols
,
tiledCols
,
outRows
,
outCols
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
}
}
...
...
tests/test_autograd.py (+10 −7)
...

@@ -40,6 +40,7 @@ names = [
     ids=names,
 )
 def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
-    dim2 = dim2 - (dim2 % 16)
+    if dim2 > 0:
+        dim2 = dim2 - (dim2 % 16)
     dim3 = dim3 - (dim3 % 16)
     dim4 = dim4 - (dim4 % 16)
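The guard keeps a deliberately injected dim2 == 0 (added to the dimension lists in the next hunk) from being rounded away, while nonzero dims are still floored to multiples of 16 as the int8 kernels expect. A tiny trace of the rounding, for illustration only:

    # nonzero dims round down to a multiple of 16; zero survives untouched
    for dim2 in (0, 17, 48, 100):
        if dim2 > 0:
            dim2 = dim2 - (dim2 % 16)
        print(dim2)   # 0, 16, 48, 96; the 0 case exercises the empty-input path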
...

@@ -234,10 +235,7 @@ dim2 = torch.randint(32, 96, size=(n,)).tolist()
 dim3 = torch.randint(32, 96, size=(n,)).tolist()
 dim4 = torch.randint(32, 96, size=(n,)).tolist()

-# dim1 = (17,)
-# dim2 = (7,)
-# dim3 = (37,)
-# dim4 = (23,)
+dim2.append(0)

 decomp = [0.0, 6.0]
 funcs = [(torch.matmul, bnb.matmul)]
...

@@ -385,9 +383,14 @@ def test_matmullt(
             )

         if req_grad[1]:
             n = gradB1.numel()
-            assert torch.abs(gradB1).sum() > 0.0
-            assert torch.abs(gradB2).sum() > 0.0
+            if dim2 > 0:
+                assert torch.abs(gradB1).sum() > 0.0
+                assert torch.abs(gradB2).sum() > 0.0
+            else:
+                assert torch.abs(gradB1).sum() == 0.0
+                assert torch.abs(gradB2).sum() == 0.0
             idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
             assert (idx == 0).sum().item() < n * 0.1
             idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
             assert (idx == 0).sum().item() < n * 0.02

...