OpenDAS / bitsandbytes · Commits · c771b3a7

Commit c771b3a7, authored Jul 22, 2022 by Tim Dettmers
Commit message: Most tests passing.
Parent: 4cd7ea62
Showing 16 changed files with 5269 additions and 159 deletions (+5269, -159)
Changed files:

  bitsandbytes/__init__.py              +2     -1
  bitsandbytes/autograd/__init__.py     +0     -0
  bitsandbytes/autograd/_functions.py   +307   -0
  bitsandbytes/cextension.py            +2     -0
  bitsandbytes/functional.py            +867   -2
  bitsandbytes/nn/__init__.py           +1     -1
  bitsandbytes/nn/modules.py            +122   -2
  csrc/kernels.cu                       +874   -0
  csrc/kernels.cuh                      +12    -0
  csrc/ops.cu                           +406   -0
  csrc/ops.cuh                          +104   -0
  csrc/pythonInterface.c                +126   -1
  tests/test_autograd.py                +270   -0
  tests/test_functional.py              +1704  -59
  tests/test_modules.py                 +453   -25
  tests/test_optim.py                   +19    -68
bitsandbytes/__init__.py (view file @ c771b3a7)

@@ -4,12 +4,13 @@
 # LICENSE file in the root directory of this source tree.

 from .nn import modules
+from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState
 from .cextension import COMPILED_WITH_CUDA

 if COMPILED_WITH_CUDA:
     from .optim import adam

-__pdoc__ = {'libBitsNBytes' : False,
+__pdoc__ = {'libbitsandbytes' : False,
            'optim.optimizer.Optimizer8bit': False,
            'optim.optimizer.MockArgs': False
           }
bitsandbytes/autograd/__init__.py (new file, mode 0 → 100644; empty, +0 -0) (view file @ c771b3a7)
bitsandbytes/autograd/_functions.py (new file, mode 0 → 100644) (view file @ c771b3a7)

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

from dataclasses import dataclass

tensor = torch.Tensor

'''
    This class pools outlier dimensions across layers.
    This is particularly important for small models where outlier features
    are less systematic and occur with low frequency.
'''
class GlobalOutlierPooler(object):
    _instance = None

    def __init__(self):
        raise RuntimeError('Call get_instance() instead')

    def initialize(self):
        self.outliers = set()
        self.model_dim = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def add_outliers(self, outlier_idx, feature_dim):
        if self.model_dim is None: self.model_dim = feature_dim
        if feature_dim != self.model_dim: return  # we do not encode outliers for the 2nd FFN layer

        self.outliers.update(outlier_idx.tolist())

    def get_current_outlier_idx(self):
        return torch.Tensor(list(self.outliers)).to(torch.int64)

class MatMul8bit(torch.autograd.Function):

    @staticmethod
    def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]):

        if precision[0] != 8:
            with torch.no_grad():
                output = torch.matmul(A, B)
        else:
            if len(B.shape) == 2: dim = 0
            else: dim = 1
            qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
            qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
            iout = F.igemm(qA, qB)
            output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)

        if A.requires_grad or B.requires_grad:
            ctx.save_for_backward(A, B)

        ctx.quant_type = quant_type
        ctx.precision = precision

        return output

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        quant_type = ctx.quant_type
        precision = ctx.precision
        grad_A = grad_B = None

        if B.requires_grad:
            if len(A.shape) == 3:
                dims = [0, 1]
                # bsi -> ibs
                permute_dim = [0, 2, 1]
            else:
                dims = [0]
                # bs -> sb
                permute_dim = [1, 0]

            if precision[1] != 8:
                with torch.no_grad():
                    grad_B = torch.matmul(A.permute(permute_dim), grad_output)
            else:
                if len(B.shape) == 2 and len(A.shape) == 3:
                    grad_output = grad_output.contiguous()
                    if not grad_output.is_contiguous(): grad_output.contiguous()
                    qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type)
                    if not A.is_contiguous(): A = A.contiguous()
                    qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
                    igrad_B = F.igemm(qA.t(), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
                else:
                    qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
                    qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
                    igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type)

        if A.requires_grad:
            if len(grad_output.shape) == 3: dims = [2]
            else: dims = [1]

            if len(B.shape) == 3:
                # bio -> boi
                permute_dim = [0, 2, 1]
                dim_B = dims
            else:
                # io -> oi
                permute_dim = [1, 0]
                dim_B = [1]

            if precision[2] != 8:
                with torch.no_grad():
                    grad_A = torch.matmul(grad_output, B.permute(permute_dim))
            else:
                qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
                qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
                grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type)

        return grad_A, grad_B, None, None, None

mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply

@dataclass
class MatmulLtState:
    CB = None
    CxB = None
    SB = None
    SCB = None

    CxBt = None
    SBt = None
    CBt = None

    subB = None

    outlier_pool = None
    has_accumulated_gradients = False
    threshold = 0.0
    idx = None
    is_training = True
    has_fp16_weights = True
    use_pool = False
    formatB = F.get_special_format_str()

    def reset_grads(self):
        self.CB = None
        self.CxB = None
        self.SB = None
        self.SCB = None

        self.CxBt = None
        self.SBt = None
        self.CBt = None

class MatMul8bitLt(torch.autograd.Function):

    @staticmethod
    def forward(ctx, A, B, out=None, state=MatmulLtState()):
        # 1. Quantize A
        # 2. Quantize B
        # 3. Matmul
        # 4. Mixed-precision decomposition matmul
        # 5. Save state
        requires_gradA = A.requires_grad
        requires_gradB = B.requires_grad
        formatB = state.formatB
        input_shape = A.shape
        if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance()
        assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!'

        # 1. Quantize A
        if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)

        if state.threshold > 0.0 and coo_tensorA is not None:
            if state.has_fp16_weights:
                idx = torch.unique(coo_tensorA.colidx).long()
                CA[:, idx] = 0
                CAt[:, idx] = 0
                subA = A[:, idx]
                state.subB = B[:, idx].t().contiguous()
                state.idx = idx
            else:
                if state.CxB is None:
                    # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                    # we also need to convert it to the turing/ampere format
                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
                if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
                    # generate outlier index and subB
                    outlier_idx = torch.unique(coo_tensorA.colidx).long()
                    state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
                    if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
                        # do not use pool for 2nd FFN layer
                        state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
                    else:
                        state.idx = outlier_idx
                    state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()

                if state.idx is not None:
                    # extract outliers
                    CA[:, state.idx] = 0
                    CAt[:, state.idx] = 0
                    subA = A[:, state.idx]
                else:
                    subA = None
        else:
            if not state.has_fp16_weights and state.CxB is None:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None

        C32A, SA = F.transform(CA, 'col32')

        # 2. Quantize B
        if state.has_fp16_weights:
            has_grad = (True if (getattr(B, 'grad', None) is not None) else False)
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed: B = B.contiguous()

            if (state.is_training and not has_grad) or state.CxB is None:
                state.reset_grads()
                CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B)
                state.CxB, state.SB = F.transform(CB, to_order=formatB)
        else:
            has_grad = False

        shapeB = state.SB[0]

        if len(input_shape) == 3:
            output_shape = (input_shape[0], input_shape[1], shapeB[0])
        else:
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
        output = F.mm_dequant(out32, Sout32, SCA, state.SCB)

        # 4. Mixed-precision decomposition matmul
        if state.threshold > 0.0 and coo_tensorA is not None and subA is not None:
            output += torch.matmul(subA, state.subB)

        # 5. Save state
        ctx.state = state

        ctx.formatB = formatB
        ctx.grad_shape = input_shape
        ctx.req_grads = [requires_gradA, requires_gradB]

        if requires_gradA or requires_gradB:
            ctx.tensors = (CAt, subA)
            ctx.tensor_states = (SCAt, state.idx)
        else:
            ctx.tensors = [None, None]
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        #clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
        clone_func = torch.clone
        return clone_func(output.view(output_shape))

    @staticmethod
    def backward(ctx, grad_output):
        req_gradA, req_gradB = ctx.req_grads
        CAt, subA = ctx.tensors
        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state
        assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.'

        if len(grad_output.shape) == 3:
            grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()

        grad_A = grad_B = None

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
        if req_gradB:
            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
            C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True)
            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
            if state.threshold > 0.0 and subA is not None:
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if req_gradA:
            C32grad, Sgrad = F.transform(Cgrad, 'col32')
            if state.CxBt is None:
                state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)

        return grad_A, grad_B, None, None, None, None, None

matmul = MatMul8bitLt.apply

def matmul(A : tensor, B : tensor, out : tensor=None, state : MatmulLtState = None, threshold=0.0):
    state = state or MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return MatMul8bitLt.apply(A, B, out, state)
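The matmul wrapper at the end is the public entry point for the new path: it builds a MatmulLtState, copies the threshold into it, and dispatches to MatMul8bitLt. A usage sketch (not from this diff), assuming a CUDA device, a GPU build of the library, and illustrative shapes and threshold:

    import torch
    import bitsandbytes as bnb

    # fp16 inputs are required by MatMul8bitLt.forward; B is laid out as
    # (out_features, in_features) because igemmlt only supports A @ B^T.
    A = torch.randn(4, 1024, dtype=torch.float16, device='cuda', requires_grad=True)
    B = torch.randn(2048, 1024, dtype=torch.float16, device='cuda', requires_grad=True)

    # threshold > 0.0 turns on the mixed-precision decomposition: outlier columns
    # of A are zeroed in the int8 path and multiplied separately in fp16.
    out = bnb.matmul(A, B, threshold=6.0)   # shape (4, 2048)
    out.float().sum().backward()            # backward only supported for fp16 weights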
bitsandbytes/cextension.py (view file @ c771b3a7)

@@ -6,6 +6,8 @@ lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
 try:
     lib.cadam32bit_g32
+    lib.get_context.restype = ct.c_void_p
+    lib.get_cusparse.restype = ct.c_void_p
     COMPILED_WITH_CUDA = True
 except AttributeError:
     warn("The installed version of bitsandbytes was compiled without GPU support. "
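The two added restype lines matter because ctypes assumes every foreign function returns a C int, so the 64-bit handles returned by get_context/get_cusparse would otherwise be truncated. A small illustration of the same pattern (not from this diff), using libc's malloc instead of libbitsandbytes.so and assuming a POSIX system:

    import ctypes as ct

    libc = ct.CDLL(None)                 # the already-linked C library (POSIX)
    libc.malloc.restype = ct.c_void_p    # same idea as lib.get_context.restype
    libc.free.argtypes = [ct.c_void_p]

    handle = ct.c_void_p(libc.malloc(64))
    print(hex(handle.value))             # full 64-bit address, not a truncated int
    libc.free(handle)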
bitsandbytes/functional.py (view file @ c771b3a7)

@@ -36,9 +36,51 @@ if COMPILED_WITH_CUDA:
    str2optimizer8bit_blockwise['rmsprop'] = (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16)
    str2optimizer8bit_blockwise['adagrad'] = (lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16)
optimal_normal = [
    -0.9939730167388916, -0.8727636337280273, -0.8097418546676636, -0.7660024166107178, -0.7318882346153259, -0.6793879270553589,
    -0.657649040222168, -0.6385974884033203, -0.6211113333702087, -0.5901028513908386, -0.5762918591499329, -0.5630806684494019,
    -0.5509274005889893, -0.5394591689109802, -0.5283197164535522, -0.517780065536499, -0.5074946284294128, -0.4980469048023224,
    -0.48867011070251465, -0.48003149032592773, -0.47125306725502014, -0.4629971981048584, -0.4547359049320221, -0.446626216173172,
    -0.43902668356895447, -0.43158355355262756, -0.4244747757911682, -0.4173796474933624, -0.41038978099823, -0.4055633544921875,
    -0.4035947024822235, -0.39701032638549805, -0.39057496190071106, -0.38439232110977173, -0.3782760500907898, -0.3721940815448761,
    -0.3661896586418152, -0.3604033589363098, -0.354605108499527, -0.34892538189888, -0.34320303797721863, -0.3376772701740265,
    -0.3323028087615967, -0.3269782066345215, -0.32166096568107605, -0.316457599401474, -0.3112771809101105, -0.3061025142669678,
    -0.30106794834136963, -0.2961243987083435, -0.2912728488445282, -0.28644347190856934, -0.28165507316589355, -0.2769731283187866,
    -0.2722635865211487, -0.26779335737228394, -0.26314786076545715, -0.2586647868156433, -0.2541804611682892, -0.2496625930070877,
    -0.24527113139629364, -0.24097171425819397, -0.23659978806972504, -0.23218469321727753, -0.22799566388130188, -0.22380566596984863,
    -0.21965542435646057, -0.2154538631439209, -0.2113603949546814, -0.20735277235507965, -0.20334717631340027, -0.19932441413402557,
    -0.19530178606510162, -0.19136647880077362, -0.18736697733402252, -0.18337111175060272, -0.17951400578022003, -0.1757056713104248,
    -0.17182783782482147, -0.1680615097284317, -0.16431649029254913, -0.16053077578544617, -0.15685945749282837, -0.15298527479171753,
    -0.1493264138698578, -0.14566898345947266, -0.14188314974308014, -0.13819937407970428, -0.1344561129808426, -0.1306886374950409,
    -0.1271020770072937, -0.12346585839986801, -0.11981867253780365, -0.11614970862865448, -0.11256207525730133, -0.10889036953449249,
    -0.10525048524141312, -0.1016591489315033, -0.09824034571647644, -0.09469068050384521, -0.0911419615149498, -0.08773849159479141,
    -0.08416644483804703, -0.08071305602788925, -0.07720902562141418, -0.07371306419372559, -0.07019119709730148, -0.06673648208379745,
    -0.06329209357500076, -0.059800852090120316, -0.0564190037548542, -0.05296570807695389, -0.049522045999765396, -0.04609023034572601,
    -0.04262964054942131, -0.039246633648872375, -0.03577171266078949, -0.03236335143446922, -0.028855687007308006, -0.02542758360505104,
    -0.022069433704018593, -0.018754752352833748, -0.015386369079351425, -0.01194947212934494, -0.008439815603196621, -0.004995611496269703,
    -0.0016682245768606663, 0.0, 0.0015510577941313386, 0.005062474869191647, 0.008417150937020779, 0.011741090565919876,
    0.015184164978563786, 0.018582714721560478, 0.02204744517803192, 0.025471193715929985, 0.02889077737927437, 0.0323684960603714,
    0.03579240292310715, 0.039281025528907776, 0.0427563451230526, 0.04619763046503067, 0.04968220740556717, 0.05326594039797783,
    0.05679265409708023, 0.060245808213949203, 0.06372645497322083, 0.06721872836351395, 0.0706876739859581, 0.0742349922657013,
    0.07774098962545395, 0.08123527467250824, 0.08468879014253616, 0.08810535818338394, 0.09155989438295364, 0.09498448669910431,
    0.0985206812620163, 0.10206405073404312, 0.10563778132200241, 0.10921968519687653, 0.11284469068050385, 0.11653254181146622,
    0.12008969485759735, 0.12368203699588776, 0.1272617131471634, 0.13089501857757568, 0.134552001953125, 0.1382799744606018,
    0.14194637537002563, 0.14563234150409698, 0.14930322766304016, 0.15303383767604828, 0.1567956507205963, 0.16050070524215698,
    0.16431072354316711, 0.16813558340072632, 0.17204202711582184, 0.1758781224489212, 0.17973239719867706, 0.1836014688014984,
    0.18753431737422943, 0.19138391315937042, 0.19535475969314575, 0.19931404292583466, 0.20333819091320038, 0.20738255977630615,
    0.21152682602405548, 0.21568812429904938, 0.21978361904621124, 0.22393859922885895, 0.22814159095287323, 0.23241068422794342,
    0.23675410449504852, 0.24123944342136383, 0.24569889903068542, 0.2500703036785126, 0.25904011726379395, 0.26349544525146484,
    0.2682226300239563, 0.272907555103302, 0.2774306833744049, 0.28220856189727783, 0.2869136929512024, 0.2916390895843506,
    0.29649388790130615, 0.30142995715141296, 0.3065022826194763, 0.3114383816719055, 0.31648796796798706, 0.3216581642627716,
    0.32700115442276, 0.3322487473487854, 0.33778008818626404, 0.3431521952152252, 0.3487405776977539, 0.3543166518211365,
    0.3601346015930176, 0.36605337262153625, 0.37217751145362854, 0.378179669380188, 0.3843980133533478, 0.3906566798686981,
    0.39714935421943665, 0.40357843041419983, 0.4104187488555908, 0.4171563684940338, 0.42418959736824036, 0.43136918544769287,
    0.4389212429523468, 0.44673123955726624, 0.45457619428634644, 0.4627031683921814, 0.47130417823791504, 0.4798591434955597,
    0.48897242546081543, 0.4979848861694336, 0.5, 0.5076631307601929, 0.5177803635597229, 0.5282770991325378,
    0.5392990112304688, 0.5506287813186646, 0.5632893443107605, 0.5764452815055847, 0.5903191566467285, 0.6051878333091736,
    0.6209936141967773, 0.6382884979248047, 0.6573970913887024, 0.6795773506164551, 0.7037051916122437, 0.7327037453651428,
    0.7677436470985413, 0.8111193776130676, 0.875165581703186, 1.0]

optimal_half_normal = [
    0.0025565922260284424, 0.005811259150505066, 0.00961565226316452, 0.010822802782058716, 0.013123787939548492, 0.014242202043533325,
    0.0143156498670578, 0.016469404101371765, 0.017666727304458618, 0.01773911714553833, 0.0199756920337677, 0.0210941880941391,
    0.021161124110221863, 0.02451971173286438, 0.024580076336860657, 0.02685210108757019, 0.028012827038764954, 0.030198264867067337,
    0.0302925705909729, 0.03136435151100159, 0.03374280035495758, 0.03487399220466614, 0.035243816673755646, 0.037192340940237045,
    0.03822284936904907, 0.04164902865886688, 0.04173608124256134, 0.04401407018303871, 0.04508155584335327, 0.047482021152973175,
    0.04756556823849678, 0.050963032990694046, 0.05196474492549896, 0.055417388677597046, 0.05793146416544914, 0.05799369141459465,
    0.05887940526008606, 0.05895659327507019, 0.062420234084129333, 0.06493274495005608, 0.06499008461833, 0.06935599446296692,
    0.07197384163737297, 0.07201516255736351, 0.07276943325996399, 0.07283210754394531, 0.07550075277686119, 0.07975354790687561,
    0.07980883121490479, 0.08257630094885826, 0.0867777168750763, 0.08682405948638916, 0.08967285975813866, 0.09323835000395775,
    0.09386616945266724, 0.09735457599163055, 0.09739077091217041, 0.10092401504516602, 0.10444298386573792, 0.10447832942008972,
    0.10770941898226738, 0.10803905129432678, 0.11161200702190399, 0.1151546835899353, 0.11520349979400635, 0.11875157058238983,
    0.11879390478134155, 0.1222602017223835, 0.122351735830307, 0.12240418791770935, 0.12594850733876228, 0.12597402930259705,
    0.12602100148797035, 0.12960633635520935, 0.1296597123146057, 0.12966342642903328, 0.13227657973766327, 0.13325360417366028,
    0.1333133578300476, 0.13691483438014984, 0.1371927298605442, 0.14066261053085327, 0.14088113978505135, 0.1447291411459446,
    0.14805573225021362, 0.148526418954134, 0.15170684456825256, 0.15178103744983673, 0.15225710347294807, 0.1554398238658905,
    0.15609459951519966, 0.15618794038891792, 0.1592724472284317, 0.1629735231399536, 0.16382690146565437, 0.16676269471645355,
    0.16873238794505596, 0.17066434025764465, 0.17068277299404144, 0.1717144437134266, 0.17558929696679115, 0.17827065289020538,
    0.17835864424705505, 0.18222273886203766, 0.18353315070271492, 0.18604370951652527, 0.18611834943294525, 0.1876586265861988,
    0.18996606767177582, 0.19170701876282692, 0.19398853182792664, 0.19786442816257477, 0.19795633852481842, 0.20195159316062927,
    0.2058800607919693, 0.2099103182554245, 0.2122517265379429, 0.21410366892814636, 0.21819619834423065, 0.22221362590789795,
    0.22233009338378906, 0.22500130906701088, 0.2251257635653019, 0.22638091444969177, 0.23067741096019745, 0.23368822410702705,
    0.2348879873752594, 0.2382080741226673, 0.2390350103378296, 0.2391497790813446, 0.24253453686833382, 0.24265171959996223,
    0.2470107562839985, 0.24764248728752136, 0.24777774512767792, 0.2516774423420429, 0.256104726344347, 0.2564055472612381,
    0.2607169933617115, 0.265461727976799, 0.26985861361026764, 0.2701106257736683, 0.2702729292213917, 0.274574413895607,
    0.2750340588390827, 0.27919672429561615, 0.283704474568367, 0.28386808931827545, 0.28953738883137703, 0.2896753139793873,
    0.29320384562015533, 0.29451676085591316, 0.295327290892601, 0.29802779853343964, 0.29818175733089447, 0.29972871020436287,
    0.30290623009204865, 0.30305664241313934, 0.30486901476979256, 0.31299956142902374, 0.31518544629216194, 0.31790371239185333,
    0.3205283172428608, 0.3230419009923935, 0.32595496252179146, 0.32612212374806404, 0.3282426446676254, 0.3283906430006027,
    0.33146094158291817, 0.3316439874470234, 0.33365286886692047, 0.33723779395222664, 0.3390095978975296, 0.3427443392574787,
    0.34853987768292427, 0.34869300201535225, 0.35457711294293404, 0.35537679493427277, 0.3604113645851612, 0.36124424636363983,
    0.3665340431034565, 0.36667295172810555, 0.3727492541074753, 0.3729033060371876, 0.37888188660144806, 0.37907837703824043,
    0.3792510814964771, 0.38557394221425056, 0.38573457673192024, 0.39108292758464813, 0.39911722019314766, 0.40589402988553047,
    0.40604450181126595, 0.410498782992363, 0.4106704741716385, 0.4129834659397602, 0.4131447561085224, 0.4172855168581009,
    0.4202354736626148, 0.4204071946442127, 0.43538858368992805, 0.4355536885559559, 0.4432900734245777, 0.44603554904460907,
    0.4461968094110489, 0.451409537345171, 0.4598204083740711, 0.46002377942204475, 0.46178819239139557, 0.46868549659848213,
    0.46995367109775543, 0.4868385046720505, 0.48702501133084297, 0.4958047419786453, 0.4960057884454727, 0.5051481872797012,
    0.506847757846117, 0.5148334950208664, 0.5150565356016159, 0.5174009390175343, 0.5249751061201096, 0.5283288545906544,
    0.5355450958013535, 0.539984006434679, 0.5467876642942429, 0.5522958822548389, 0.5584012717008591, 0.5706631988286972,
    0.5836620181798935, 0.5836880058050156, 0.5942088551819324, 0.5975865572690964, 0.6102624125778675, 0.6124880760908127,
    0.6286389082670212, 0.646102175116539, 0.6471664495766163, 0.665437325835228, 0.6687244363129139, 0.687017485499382,
    0.6932839937508106, 0.7115348428487778, 0.7218200154602528, 0.7219699807465076, 0.7747527211904526, 0.7749756425619125,
    0.8192005604505539, 0.8194110840559006, 0.8830635994672775, 0.9217727445065975, 0.9245667457580566, 0.947742685675621,
    0.9674464613199234, 0.9890814647078514, 0.9891453236341476, 0.9925699159502983]
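The two tables above are 8-bit quantization maps (normalized code books over [-1, 1] and [0, 1]). As a hedged illustration of how such a code book can be applied, the sketch below snaps normalized values to their nearest code entry with plain PyTorch; the helper names are made up here, and this is not the kernel path the library uses:

    import torch
    from bitsandbytes.functional import optimal_normal   # the 256-entry table defined above

    def quantize_with_code(x, code):
        # x: float tensor; code: 1D sorted tensor of quantization levels in [-1, 1].
        absmax = x.abs().max().clamp_min(1e-8)
        xn = (x / absmax).clamp(-1, 1).flatten()          # normalize into the code's range
        idx = torch.searchsorted(code, xn).clamp(1, code.numel() - 1)
        left, right = code[idx - 1], code[idx]            # pick the nearer neighbour
        idx = torch.where((xn - left).abs() <= (xn - right).abs(), idx - 1, idx)
        return idx.to(torch.uint8).view_as(x), absmax

    def dequantize_with_code(q, absmax, code):
        return code[q.long()] * absmax

    code = torch.tensor(optimal_normal)
    x = torch.randn(1024)
    q, absmax = quantize_with_code(x, code)
    x_hat = dequantize_with_code(q, absmax, code)
    print((x - x_hat).abs().mean())                       # small reconstruction error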
class CUBLAS_Context(object):
    _instance = None

    def __init__(self):
        raise RuntimeError('Call get_instance() instead')

    def initialize(self):
        self.context = {}
        #prev_device = torch.cuda.current_device()
        #for i in range(torch.cuda.device_count()):
        #    torch.cuda.set_device(torch.device('cuda', i))
        #    self.context.append(ct.c_void_p(lib.get_context()))
        #torch.cuda.set_device(prev_device)

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def get_context(self, device):
        if device.index not in self.context:
            prev_device = torch.cuda.current_device()
            torch.cuda.set_device(device)
            self.context[device.index] = ct.c_void_p(lib.get_context())
            torch.cuda.set_device(prev_device)
        return self.context[device.index]

class Cusparse_Context(object):
    _instance = None

    def __init__(self):
        raise RuntimeError('Call get_instance() instead')

    def initialize(self):
        self.context = ct.c_void_p(lib.get_cusparse())

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

def create_linear_map(signed=True):
    if signed:
@@ -89,6 +131,16 @@ def create_dynamic_map(signed=True, n=7):
    data.sort()
    return Tensor(data)

def get_special_format_str():
    major, minor = torch.cuda.get_device_capability()
    if major < 7:
        print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!')
        assert major >= 7

    if major == 7: return 'col_turing'
    elif major == 8: return 'col_ampere'
    else: return 'col_turing'

def get_ptr(A: Tensor) -> ct.c_void_p:
    '''
    Get the ctypes pointer from a PyTorch Tensor.

@@ -105,6 +157,105 @@ def get_ptr(A: Tensor) -> ct.c_void_p:
    if A is None: return None
    else: return ct.c_void_p(A.data.storage().data_ptr())
def pre_call(device):
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    return prev_device

def post_call(prev_device):
    torch.cuda.set_device(prev_device)

def get_transform_func(dtype, orderA, orderOut, transpose=False):
    name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
    if not hasattr(lib, name):
        print(name)
        raise ValueError(f'Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}')
    else:
        return getattr(lib, name)

class GlobalData(object):
    _instance = None

    def __init__(self):
        raise RuntimeError('Call get_instance() instead')

    def initialize(self):
        self.data = {}

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

def get_transform_buffer(shape, dtype, device, to_order, from_order='row', transpose=False):
    #init_func = torch.empty
    init_func = torch.zeros
    dims = len(shape)

    if dims == 2:
        rows = shape[0]
    elif dims == 3:
        rows = shape[0]*shape[1]
    cols = shape[-1]

    state = (shape, to_order)
    if transpose:
        # swap dims
        tmp = rows
        rows = cols
        cols = tmp
        state = (shape[::-1], to_order)

    if to_order == 'row' or to_order == 'col':
        return init_func(shape, dtype=dtype, device=device), state
    elif to_order == 'col32':
        # blocks of 32 columns (padded)
        cols = 32*((cols+31)//32)
        return init_func((rows, cols), dtype=dtype, device=device), state
    elif to_order == 'col_turing':
        # blocks of 32 columns and 8 rows
        cols = 32*((cols+31)//32)
        rows = 8*((rows+7)//8)
        return init_func((rows, cols), dtype=dtype, device=device), state
    elif to_order == 'col_ampere':
        # blocks of 32 columns and 32 rows
        cols = 32*((cols+31)//32)
        rows = 32*((rows+31)//32)
        return init_func((rows, cols), dtype=dtype, device=device), state
    else:
        raise NotImplementedError(f'To_order not supported: {to_order}')

def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
    if state is None: state = (A.shape, from_order)
    else: from_order = state[1]
    if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1])
    else: new_state = (state[1], to_order)
    func = get_transform_func(A.dtype, from_order, to_order, transpose)

    shape = state[0]
    if len(shape) == 2:
        dim1 = ct.c_int32(shape[0])
        dim2 = ct.c_int32(shape[1])
    elif ld is not None:
        n = math.prod(shape)
        dim1 = math.prod([shape[i] for i in ld])
        dim2 = ct.c_int32(n//dim1)
        dim1 = ct.c_int32(dim1)
    else:
        dim1 = ct.c_int32(shape[0]*shape[1])
        dim2 = ct.c_int32(shape[2])

    ptr = CUBLAS_Context.get_instance().get_context(A.device)
    ptrA = get_ptr(A)
    ptrOut = get_ptr(out)
    func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)

    return out, new_state

def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tensor:
    '''
    Estimates 256 equidistant quantiles on the input tensor eCDF.
@@ -544,3 +695,717 @@ def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor,
    maxdim1 = ct.c_int32(histogram.shape[0])
    n = ct.c_int32(index1.numel())
    lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
    if not torch.cuda.is_initialized(): torch.cuda.init()
    if A.dtype != expected_type or B.dtype != expected_type:
        raise TypeError(f'Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}')

    sA = A.shape
    sB = B.shape
    tA = transposed_A
    tB = transposed_B

    correct = True

    if len(sA) == 2 and len(sB) == 2:
        if not tA and not tB and A.shape[1] != B.shape[0]: correct = False
        elif tA and not tB and A.shape[0] != B.shape[0]: correct = False
        elif tA and tB and A.shape[0] != B.shape[1]: correct = False
        elif not tA and tB and A.shape[1] != B.shape[1]: correct = False
    elif len(sA) == 3 and len(sB) == 2:
        if not tA and not tB and A.shape[2] != B.shape[0]: correct = False
        elif tA and not tB and A.shape[1] != B.shape[0]: correct = False
        elif tA and tB and A.shape[1] != B.shape[1]: correct = False
        elif not tA and tB and A.shape[2] != B.shape[1]: correct = False
    elif len(sA) == 3 and len(sB) == 3:
        if not tA and not tB and A.shape[2] != B.shape[1]: correct = False
        elif tA and not tB and A.shape[1] != B.shape[1]: correct = False
        elif tA and tB and A.shape[1] != B.shape[2]: correct = False
        elif not tA and tB and A.shape[2] != B.shape[2]: correct = False

    if out is not None:
        sout = out.shape
        # special case common in backprop
        if not correct and len(sA) == 3 and len(sB) == 3:
            if (sout[0] == sA[2] and sout[1] == sB[2] and sA[0] == sB[0] and sA[1] == sB[1]):
                correct = True
    else:
        if len(sA) == 2 and len(sB) == 2:
            if not tA and not tB: sout = (sA[0], sB[1])
            elif tA and tB: sout = (sA[1], sB[0])
            elif tA and not tB: sout = (sA[1], sB[1])
            elif not tA and tB: sout = (sA[0], sB[0])
        elif len(sA) == 3 and len(sB) == 2:
            if not tA and not tB: sout = (sA[0], sA[1], sB[1])
            elif tA and tB: sout = (sA[0], sA[2], sB[0])
            elif tA and not tB: sout = (sA[0], sA[2], sB[1])
            elif not tA and tB: sout = (sA[0], sA[1], sB[0])
        elif len(sA) == 3 and len(sB) == 3:
            if not tA and not tB: sout = (sA[0], sA[1], sB[2])
            elif tA and tB: sout = (sA[0], sA[2], sB[1])
            elif tA and not tB: sout = (sA[0], sA[2], sB[2])
            elif not tA and tB: sout = (sA[0], sA[1], sB[1])

    if not correct:
        raise ValueError(f'Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.')

    return sout
def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False):
    sout = check_matmul(A, B, out, transposed_A, transposed_B)
    if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
    if len(A.shape) == 3 and len(B.shape) == 3:
        if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]:
            return batched_igemm(A, B, out)

    sA = A.shape
    sB = B.shape
    if transposed_A and len(sA) == 2: sA = (sA[1], sA[0])
    elif transposed_A and len(sA) == 3: sA = (sA[0], sA[2], sA[0])
    if transposed_B and len(sB) == 2: sB = (sB[1], sB[0])
    elif transposed_B and len(sB) == 3: sB = (sB[0], sB[2], sB[0])
    # this is a mess: cuBLAS expect column major, but PyTorch is row major.
    # So to perform the matrix multiplication, we have to treat A, B, and C matrices
    # (transpose of row major is column major)
    # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
    # matrices in the input arguments for cuBLAS
    # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
    if len(sB) == 2:
        if B.stride()[0] == B.shape[1]: transposed_B = False
        elif B.stride()[1] == B.shape[0]: transposed_B = True
        if len(A.shape) == 2:
            if A.stride()[0] == A.shape[1]: transposed_A = False
            elif A.stride()[1] == A.shape[0]: transposed_A = True
        else:
            if A.stride()[1] == A.shape[2]: transposed_A = False
            elif A.stride()[2] == A.shape[1]: transposed_A = True

        if len(sA) == 2:
            n = sA[0]
            ldb = A.stride()[1 if transposed_A else 0]
        elif len(sA) == 3 and len(sB) == 2:
            n = sA[0]*sA[1]
            ldb = sA[2]

        m = sB[1]
        k = sB[0]
        lda = B.stride()[(1 if transposed_B else 0)]
        ldc = sB[1]
    elif len(sB) == 3:
        # special case
        assert len(sA) == 3
        if not (sA[0] == sB[0] and sA[1] == sB[1]):
            raise ValueError(f'Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}')

        transposed_A = True
        transposed_B = False

        m = sB[2]
        n = sA[2]
        k = sB[0]*sB[1]

        lda = m
        ldb = sA[2]
        ldc = m

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    # B^T @ A^T = C^T
    # [km, nk -> mn]
    lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
               get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
    return out

def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False):
    if not len(A.shape) == 3 or not len(B.shape) == 3:
        raise ValueError(f'Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}')
    sout = check_matmul(A, B, out, transposed_A, transposed_B)
    if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)

    if B.is_contiguous():
        lda = B.stride()[1]
        transposed_A = False
    else:
        s = B.stride()
        if s[0] != B.shape[0]:
            B = B.contiguous()
            lda = B.stride()[1]
        elif s[2] == B.shape[1]:
            transposed_A = True
            lda = B.stride()[2]
        else:
            if s[2] == 1:
                B = B.contiguous()
                lda = B.stride()[1]
            elif s[1] == 1:
                B = B.contiguous()
                lda = B.stride()[1]
            else:
                B = B.contiguous()
                lda = B.stride()[1]

    if A.is_contiguous():
        ldb = A.stride()[1]
        transposed_B = False
    else:
        s = A.stride()
        if s[0] != A.shape[0]:
            A = A.contiguous()
            ldb = A.stride()[1]
            transposed_B = False
        elif s[2] == A.shape[1]:
            ldb = A.stride()[2]
            transposed_B = True
        else:
            A = A.contiguous()
            ldb = A.stride()[1]
            transposed_B = False

    # this is a mess: cuBLAS expect column major, but PyTorch is row major.
    # So to perform the matrix multiplication, we have to treat A, B, and C matrices
    # (transpose of row major is column major)
    # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
    # matrices in the input arguments for cuBLAS
    # column major: A @ B = C: [batch, m, k] @ [batch, k, n] = [batch, m, n]
    # row major: B^T @ A^T = C^T: [batch, m, k] @ [batch, k, n] = [batch, m, n]
    # column major with row major layout: B^T @ A^T = C^T: [batch, k, m] @ [batch, n, k] = [batch, n, m]
    num_batch = A.shape[0]
    n = A.shape[1]
    m = B.shape[2]
    k = B.shape[1]

    ldc = m

    strideA = B.shape[1]*B.shape[2]
    strideB = A.shape[1]*A.shape[2]
    strideC = A.shape[1]*B.shape[2]

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
                       get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
                       ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
    return out
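igemm and batched_igemm wrap the cuBLAS int8 GEMM exposed as lib.cigemm/lib.cbatched_igemm and accumulate into int32. A small sanity-check sketch (not from this diff), assuming a CUDA device and a GPU build of the library:

    import torch
    import bitsandbytes.functional as F

    A = torch.randint(-128, 127, (32, 64), dtype=torch.int8, device='cuda')
    B = torch.randint(-128, 127, (64, 16), dtype=torch.int8, device='cuda')

    out = F.igemm(A, B)                    # int32 output of shape (32, 16)
    ref = A.float() @ B.float()            # reference product in float32
    print((out.float() - ref).abs().max()) # expected to be 0: int8 products accumulate exactly in int32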
def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32):
    shapeA = SA[0]
    shapeB = SB[0]
    dimsA = len(shapeA)
    dimsB = len(shapeB)
    if dimsA == 2:
        m = shapeA[0]
    elif dimsA == 3:
        m = shapeA[0]*shapeA[1]

    if dimsB == 2:
        rows = n = shapeB[0]
    elif dimsB == 3:
        rows = n = shapeB[0]*shapeB[1]

    if dimsA == 2 and out is None:
        out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row')
    elif dimsA == 3 and out is None:
        out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row')

    if row_scale is not None: assert row_scale.numel() == out.shape[0]
    assert dimsB != 3, 'len(B.shape)==3 not supported'
    assert A.device.type == 'cuda'
    assert B.device.type == 'cuda'
    assert A.dtype == torch.int8
    assert B.dtype == torch.int8
    assert out.dtype == dtype
    assert SA[1] == 'col32'
    assert SB[1] in ['col_turing', 'col_ampere']
    assert Sout[1] == 'col32'
    assert shapeA[-1] == shapeB[-1], f'Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}'
    formatB = SB[1]
    prev_device = A.device
    torch.cuda.set_device(A.device)

    ptr = CUBLAS_Context.get_instance().get_context(A.device)
    ptrA = get_ptr(A)
    ptrB = get_ptr(B)
    ptrC = get_ptr(out)
    ptrRowScale = get_ptr(row_scale)

    k = shapeA[-1]
    lda = ct.c_int32(m*32)
    if formatB == 'col_turing':
        # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
        # n = rows
        ldb = ct.c_int32(((rows+7)//8)*8*32)
    else:
        # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
        # n = rows
        ldb = ct.c_int32(((rows+31)//32)*32*32)
    ldc = ct.c_int32(m*32)

    m = ct.c_int32(m)
    n = ct.c_int32(n)
    k = ct.c_int32(k)

    has_error = 0
    if formatB == 'col_turing':
        if dtype == torch.int32:
            has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
        elif row_scale is None:
            has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
        else:
            has_error = lib.cigemmlt_turing_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
    elif formatB == 'col_ampere':
        if dtype == torch.int32:
            has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
        elif row_scale is None:
            has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
        else:
            has_error = lib.cigemmlt_ampere_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)

    if has_error == 1:
        raise Exception('cublasLt ran into an error!')

    torch.cuda.set_device(prev_device)

    return out, Sout
def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None):
    assert A.dtype == torch.int32
    out_shape = quant_state[0]
    if len(out_shape) == 3: out_shape = (out_shape[0]*out_shape[1], out_shape[2])

    if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
    if new_row_stats is None: new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
    if new_col_stats is None: new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
    assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
    assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"

    ptrA = get_ptr(A)
    ptrOut = get_ptr(out)
    ptrRowStats = get_ptr(row_stats)
    ptrColStats = get_ptr(col_stats)
    ptrNewRowStats = get_ptr(new_row_stats)
    ptrNewColStats = get_ptr(new_col_stats)
    numRows = ct.c_int32(out_shape[0])
    numCols = ct.c_int32(out_shape[1])

    lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)

    return out

def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
    assert A.dtype == torch.float16
    device = A.device

    cols = A.shape[-1]
    if len(A.shape) == 3:
        rows = A.shape[0]*A.shape[1]
    else:
        rows = A.shape[0]

    col_tiles = (cols+255)//256
    tiled_rows = ((rows+15)//16)*16
    if row_stats is None: row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0)
    if col_stats is None: col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0)

    if nnz_block_ptr is None and threshold > 0.0: nnz_block_ptr = torch.zeros(((tiled_rows*col_tiles)+1,), dtype=torch.int32, device=device)

    ptrA = get_ptr(A)
    ptrRowStats = get_ptr(row_stats)
    ptrColStats = get_ptr(col_stats)
    ptrNnzrows = get_ptr(nnz_block_ptr)
    rows = ct.c_int32(rows)
    cols = ct.c_int32(cols)

    prev_device = pre_call(A.device)
    lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
    post_call(prev_device)

    if threshold > 0.0:
        nnz_block_ptr.cumsum_(0)

    return row_stats, col_stats, nnz_block_ptr
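For orientation, the row/column statistics that get_colrow_absmax fills are per-row and per-column absolute maxima of the fp16 input, which double_quant then uses to scale values into int8 with a factor of 127/absmax. A plain-PyTorch sketch of the threshold=0 case (an interpretation offered for illustration, not a replacement for the CUDA kernel):

    import torch

    A = torch.randn(8, 16, dtype=torch.float16)

    row_stats = A.float().abs().amax(dim=1)   # one absmax per row -> shape (8,)
    col_stats = A.float().abs().amax(dim=0)   # one absmax per column -> shape (16,)

    # Row-wise and column-wise int8 quantization then amount to:
    CA_row = torch.round(127 * A.float() / row_stats.unsqueeze(1)).to(torch.int8)
    CA_col = torch.round(127 * A.float() / col_stats.unsqueeze(0)).to(torch.int8)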
class COOSparseTensor(object):
    def __init__(self, rows, cols, nnz, rowidx, colidx, values):
        assert rowidx.dtype == torch.int32
        assert colidx.dtype == torch.int32
        assert values.dtype == torch.float16
        assert values.numel() == nnz
        assert rowidx.numel() == nnz
        assert colidx.numel() == nnz

        self.rows = rows
        self.cols = cols
        self.nnz = nnz
        self.rowidx = rowidx
        self.colidx = colidx
        self.values = values

class CSRSparseTensor(object):
    def __init__(self, rows, cols, nnz, rowptr, colidx, values):
        assert rowptr.dtype == torch.int32
        assert colidx.dtype == torch.int32
        assert values.dtype == torch.float16
        assert values.numel() == nnz
        assert colidx.numel() == nnz
        assert rowptr.numel() == rows+1

        self.rows = rows
        self.cols = cols
        self.nnz = nnz
        self.rowptr = rowptr
        self.colidx = colidx
        self.values = values

class CSCSparseTensor(object):
    def __init__(self, rows, cols, nnz, colptr, rowidx, values):
        assert colptr.dtype == torch.int32
        assert rowidx.dtype == torch.int32
        assert values.dtype == torch.float16
        assert values.numel() == nnz
        assert rowidx.numel() == nnz
        assert colptr.numel() == cols+1

        self.rows = rows
        self.cols = cols
        self.nnz = nnz
        self.colptr = colptr
        self.rowidx = rowidx
        self.values = values

def coo2csr(cooA):
    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    values.add_(1)
    rowptr = torch.zeros((cooA.rows+1, ), dtype=torch.int32, device=cooA.rowidx.device)
    rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
    rowptr.cumsum_(0)
    return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values)

def coo2csc(cooA):
    val, col2rowidx = torch.sort(cooA.colidx)
    rowidx = cooA.rowidx[col2rowidx]
    values = cooA.values[col2rowidx]
    colvalues, counts = torch.unique(val, return_counts=True)
    colvalues.add_(1)
    colptr = torch.zeros((cooA.cols+1, ), dtype=torch.int32, device=cooA.colidx.device)
    colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
    colptr.cumsum_(0)
    return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)

def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
    rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
    colidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
    values = torch.zeros((nnz,), dtype=dtype, device=device)
    return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
    device = A.device
    assert A.dtype == torch.half
    assert device.type == 'cuda'
    prev_device = pre_call(A.device)

    cols = A.shape[-1]
    if len(A.shape) == 3:
        rows = A.shape[0]*A.shape[1]
    else:
        rows = A.shape[0]

    if row_stats is None or col_stats is None:
        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)

    if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
    if out_row is None: out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)

    coo_tensor = None
    ptrA = get_ptr(A)
    ptrColStats = get_ptr(col_stats)
    ptrRowStats = get_ptr(row_stats)
    ptrOutCol = get_ptr(out_col)
    ptrOutRow = get_ptr(out_row)

    if threshold > 0.0:
        nnz = nnz_row_ptr[-1].item()
        if nnz > 0:
            coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device)
            ptrRowIdx = get_ptr(coo_tensor.rowidx)
            ptrColIdx = get_ptr(coo_tensor.colidx)
            ptrVal = get_ptr(coo_tensor.values)
            ptrRowPtr = get_ptr(nnz_row_ptr)

            lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, ptrRowIdx, ptrColIdx, ptrVal, ptrRowPtr, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols))
            val, idx = torch.sort(coo_tensor.rowidx)
            coo_tensor.rowidx = val
            coo_tensor.colidx = coo_tensor.colidx[idx]
            coo_tensor.values = coo_tensor.values[idx]
        else:
            lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(0.0), ct.c_int32(rows), ct.c_int32(cols))
    else:
        lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols))
    post_call(prev_device)

    return out_row, out_col, row_stats, col_stats, coo_tensor

def get_special_format_str():
    major, minor = torch.cuda.get_device_capability()
    if major < 7:
        print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!')
        assert major >= 7

    if major == 7: return 'col_turing'
    elif major == 8: return 'col_ampere'
    else: return 'col_turing'
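double_quant is what MatMul8bitLt calls in step 1: it returns a row-quantized and a column-quantized int8 copy of A, the corresponding absmax statistics, and, when a threshold is set, a COO tensor holding the outlier entries that were excluded from the int8 representation. A usage sketch (not from this diff), assuming a CUDA device and a GPU build; the threshold value is illustrative:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(32, 128, dtype=torch.float16, device='cuda')

    # Without a threshold there are no outliers, so the COO tensor is None.
    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A)
    print(CA.dtype, CA.shape, SCA.shape)        # int8 copies plus per-row/per-column stats

    # With a threshold, entries whose magnitude crosses it come back as a sparse
    # COO tensor so they can be handled separately in fp16.
    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=6.0)
    if coo_tensorA is not None:
        print(coo_tensorA.nnz, coo_tensorA.rowidx.dtype, coo_tensorA.values.dtype)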
def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
    if state is None: state = (A.shape, from_order)
    else: from_order = state[1]
    if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
    else: new_state = (state[0], to_order) # (shape, order)

    shape = state[0]
    if len(shape) == 2:
        dim1 = ct.c_int32(shape[0])
        dim2 = ct.c_int32(shape[1])
    else:
        dim1 = ct.c_int32(shape[0]*shape[1])
        dim2 = ct.c_int32(shape[2])

    ptrA = get_ptr(A)
    ptrOut = get_ptr(out)
    if to_order == 'col32':
        if transpose:
            lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
        else:
            lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
    elif to_order == 'col_turing':
        if transpose:
            lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
        else:
            lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
    elif to_order == 'col_ampere':
        if transpose:
            lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
        else:
            lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
    elif to_order == 'row':
        if from_order == 'col_turing':
            lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
        elif from_order == 'col_ampere':
            lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
    else:
        raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')

    return out, new_state
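Taken together, the low-level int8 forward path is double_quant, then transform into the layouts igemmlt expects (col32 for A, col_turing/col_ampere for B), then igemmlt, then mm_dequant back to fp16 — which is exactly what MatMul8bitLt.forward does above. A condensed sketch of that pipeline (not from this diff), assuming a CUDA device with tensor cores and a GPU build; the shapes are illustrative:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(32, 128, dtype=torch.float16, device='cuda')    # activations
    B = torch.randn(256, 128, dtype=torch.float16, device='cuda')   # weight as (out, in): igemmlt computes A @ B^T

    formatB = F.get_special_format_str()          # 'col_turing' or 'col_ampere'

    CA, CAt, SCA, SCAt, _ = F.double_quant(A)     # int8 copies plus absmax stats
    CB, CBt, SCB, SCBt, _ = F.double_quant(B)

    C32A, SA = F.transform(CA, 'col32')           # A must be in col32 layout
    CxB, SB = F.transform(CB, to_order=formatB)   # B in the tensor-core tile layout

    out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)  # int32 accumulator in col32 layout
    out = F.mm_dequant(out32, Sout32, SCA, SCB)   # fp16 result, approximately A @ B.t()

    ref = (A.float() @ B.float().t()).half()
    print((out - ref).abs().max())                # small quantization error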
def spmm_coo(cooA, B, out=None):
    if out is None: out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
    nnz = cooA.nnz
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
    assert cooA.cols == B.shape[0]

    transposed_B = (False if B.is_contiguous() else True)

    ldb = B.stride()[(1 if transposed_B else 0)]
    ldc = B.shape[1]

    ptr = Cusparse_Context.get_instance().context

    ptrRowidx = get_ptr(cooA.rowidx)
    ptrColidx = get_ptr(cooA.colidx)
    ptrValues = get_ptr(cooA.values)
    ptrB = get_ptr(B)
    ptrC = get_ptr(out)
    cnnz = ct.c_int32(cooA.nnz)
    crowsA = ct.c_int32(cooA.rows)
    ccolsA = ct.c_int32(cooA.cols)
    ccolsB = ct.c_int32(B.shape[1])
    cldb = ct.c_int32(ldb)
    cldc = ct.c_int32(ldc)

    lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))

    return out

def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
    if out is None: out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype)
    nnz = cooA.nnz
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
    assert cooA.cols == B.shape[0], f'{cooA.cols} vs {B.shape}'

    transposed_B = (False if B.is_contiguous() else True)

    ldb = B.stride()[(1 if transposed_B else 0)]
    ldc = B.shape[1]

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    max_idx = max_idx.int()
    max_count = max_count.int()
    assert max_count[0] <= 32, f'Current max count per row is 8 but found {max_count[0]}.'
    assert B.dtype in [torch.float16, torch.int8]
    ptrOffset = get_ptr(offset)
    ptrMaxCount = get_ptr(max_count)
    ptrMaxIdx = get_ptr(max_idx)

    ptrRowidx = get_ptr(cooA.rowidx)
    ptrColidx = get_ptr(cooA.colidx)
    ptrValues = get_ptr(cooA.values)
    ptrB = get_ptr(B)
    ptrC = get_ptr(out)
    ptrDequantStats = get_ptr(dequant_stats)
    cnnz_rows = ct.c_int32(counts.numel())
    cnnz = ct.c_int32(cooA.nnz)
    crowsA = ct.c_int32(cooA.rows)
    ccolsA = ct.c_int32(cooA.cols)
    crowsB = ct.c_int32(B.shape[1])
    ccolsB = ct.c_int32(B.shape[1])
    cldb = ct.c_int32(ldb)
    cldc = ct.c_int32(ldc)
    #print(cooA.rowidx[:64])
    #print(cooA.colidx[:64].sort()[0])

    if B.dtype == torch.float16:
        lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
    elif B.dtype == torch.int8:
        lib.cspmm_coo_very_sparse_naive_int8(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
    #else: assertion error

    return out
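spmm_coo multiplies a COO sparse fp16 matrix (such as the outlier tensor double_quant returns when a threshold is set, or one built by hand with coo_zeros) by a dense fp16 matrix via cuSPARSE. A sketch (not from this diff), assuming a CUDA device and a GPU build; the sparsity pattern is fabricated for illustration and is sorted by row, matching what double_quant produces:

    import torch
    import bitsandbytes.functional as F

    rows, cols, nnz = 16, 64, 8
    cooA = F.coo_zeros(rows, cols, nnz, device='cuda')        # int32 indices, fp16 values

    ridx = torch.randint(0, rows, (nnz,), device='cuda').int()
    cidx = torch.randint(0, cols, (nnz,), device='cuda').int()
    vals = torch.randn(nnz, dtype=torch.float16, device='cuda')
    order = torch.argsort(ridx)                               # keep entries sorted by row index
    cooA.rowidx[:] = ridx[order]
    cooA.colidx[:] = cidx[order]
    cooA.values[:] = vals[order]

    B = torch.randn(cols, 32, dtype=torch.float16, device='cuda')
    out = F.spmm_coo(cooA, B)                                 # dense (rows, 32) result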
C = 127.0

def vectorwise_quant(x, dim=1, quant_type='vector'):
    if quant_type == 'linear':
        max1 = torch.abs(x).max().float()
        xq = torch.round(x/max1*127).to(torch.int8)
        return xq, max1
    elif quant_type in ['vector', 'row']:
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        xq = torch.round(x*(C/max1)).to(torch.int8)
        return xq, max1
    elif quant_type == 'zeropoint':
        dtype = x.dtype
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0: dyna = 1
        qx = 255./dyna
        minx = x.min()
        zpx = torch.round(minx* qx)
        x = torch.round(qx*x - zpx) + zpx
        return x, qx
    elif quant_type in ['vector-zeropoint', 'row-zeropoint']:
        dtype = x.dtype
        x = x.float()
        dyna = (torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True))
        dyna[dyna==0] = 1
        qx = 255./dyna
        minx = torch.amin(x, dim=dim, keepdim=True)
        zpx = torch.round(minx* qx)
        x = torch.round(qx*x - zpx) + zpx
        return x, qx
    elif quant_type == 'truncated-vector':
        with torch.no_grad():
            absx = torch.abs(x)
            max1 = torch.amax(absx, dim=dim, keepdim=True)
            max1 = max1*0.7
            idx = (absx > max1.expand_as(absx))
            sign = torch.sign(x[idx])
            x[idx] = max1.expand_as(absx)[idx]*sign
            xq = torch.round(x/max1*C).to(torch.int8)
        return xq, max1
    else: return None

def vectorwise_dequant(xq, max1, quant_type='vector'):
    if quant_type == 'vector':
        x = (xq/C*max1).to(torch.float32)
        return x
    else: return None

def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type='vector'):
    if quant_type == 'linear':
        norm = S1*S2/(C*C)
        # double cast needed to prevent overflows
        return (xq.float()*norm).to(dtype)
    elif quant_type == 'zeropoint':
        norm = 1.0/(S1*S2)
        return (xq.float()*norm).to(dtype)
    elif quant_type == 'row-zeropoint':
        norm = 1.0/(S1*S2)
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
        if len(S1.shape) == 2:
            x *= norm
        else:
            x *= norm
        return x.to(dtype)
    elif quant_type == 'vector-zeropoint':
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
        if len(S1.shape) == 2:
            x *= 1.0/S1
        else:
            x *= 1.0/S1
        x *= 1.0/S2.t()
        return x.to(dtype)
    elif quant_type == 'row':
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
        if len(S1.shape) == 2:
            x *= S1*S2/(C*C)
        else:
            x *= S1*S2/(C*C)
        return x.to(dtype)
    elif quant_type in ['truncated-vector', 'vector']:
        x = xq.float()
        if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
        if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
        if len(S1.shape) == 2:
            x *= S1/C
        else:
            x *= S1/C
        x *= S2/C
        return x.to(dtype)
    else: return None

def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
    offset = B.float().t().sum(0)*(SA[0]+SA[1])
    x = xq.float()
    if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
    if len(SB.shape) == 2:
        x *= SB.t()/127
    else:
        x *= SB/127
    x *= SA[1]/127
    x += offset
    return x.to(dtype)
bitsandbytes/nn/__init__.py
View file @
c771b3a7
...
@@ -2,4 +2,4 @@
...
@@ -2,4 +2,4 @@
#
#
# This source code is licensed under the MIT license found in the
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
from
.modules
import
StableEmbedding
,
Embedding
from
.modules
import
StableEmbedding
,
Linear8bit
,
Linear8bitLt
,
Int8Params
bitsandbytes/nn/modules.py
View file @
c771b3a7
...
@@ -3,14 +3,19 @@
...
@@ -3,14 +3,19 @@
# This source code is licensed under the MIT license found in the
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
import
torch
import
torch
import
bitsandbytes
as
bnb
from
typing
import
Optional
from
typing
import
Union
,
Tuple
,
Any
,
Callable
,
Iterator
,
Set
,
Optional
,
overload
,
TypeVar
,
Mapping
,
Dict
from
torch
import
Tensor
from
torch
import
Tensor
,
device
,
dtype
from
torch
import
nn
from
torch.nn.parameter
import
Parameter
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
bitsandbytes.optim
import
GlobalOptimManager
from
bitsandbytes.optim
import
GlobalOptimManager
T
=
TypeVar
(
'T'
,
bound
=
'torch.nn.Module'
)
class
StableEmbedding
(
torch
.
nn
.
Embedding
):
class
StableEmbedding
(
torch
.
nn
.
Embedding
):
def
__init__
(
self
,
num_embeddings
:
int
,
embedding_dim
:
int
,
padding_idx
:
Optional
[
int
]
=
None
,
def
__init__
(
self
,
num_embeddings
:
int
,
embedding_dim
:
int
,
padding_idx
:
Optional
[
int
]
=
None
,
max_norm
:
Optional
[
float
]
=
None
,
norm_type
:
float
=
2.
,
scale_grad_by_freq
:
bool
=
False
,
max_norm
:
Optional
[
float
]
=
None
,
norm_type
:
float
=
2.
,
scale_grad_by_freq
:
bool
=
False
,
...
@@ -70,3 +75,118 @@ class Embedding(torch.nn.Embedding):
...
@@ -70,3 +75,118 @@ class Embedding(torch.nn.Embedding):
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
sparse
)
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
sparse
)
return
emb
return
emb
class
Int8Params
(
torch
.
nn
.
Parameter
):
def
__new__
(
cls
,
data
=
None
,
requires_grad
=
True
,
has_fp16_weights
=
False
,
CB
=
None
,
SCB
=
None
):
cls
.
has_fp16_weights
=
has_fp16_weights
cls
.
CB
=
None
cls
.
SCB
=
None
if
data
is
None
:
data
=
torch
.
empty
(
0
)
return
torch
.
Tensor
.
_make_subclass
(
cls
,
data
,
requires_grad
)
def
cuda
(
self
,
device
):
if
self
.
has_fp16_weights
:
return
super
().
cuda
(
device
)
else
:
# we store the 8-bit rows-major weight
# we convert this weight to the turning/ampere weight during the first inference pass
B
=
self
.
data
.
contiguous
().
half
().
cuda
(
device
)
CB
,
CBt
,
SCB
,
SCBt
,
coo_tensorB
=
bnb
.
functional
.
double_quant
(
B
)
del
CBt
del
SCBt
self
.
data
=
CB
setattr
(
self
,
'CB'
,
CB
)
setattr
(
self
,
'SCB'
,
SCB
)
return
self
@
overload
def
to
(
self
:
T
,
device
:
Optional
[
Union
[
int
,
device
]]
=
...,
dtype
:
Optional
[
Union
[
dtype
,
str
]]
=
...,
non_blocking
:
bool
=
...)
->
T
:
...
@
overload
def
to
(
self
:
T
,
dtype
:
Union
[
dtype
,
str
],
non_blocking
:
bool
=
...)
->
T
:
...
@
overload
def
to
(
self
:
T
,
tensor
:
Tensor
,
non_blocking
:
bool
=
...)
->
T
:
...
def
to
(
self
,
*
args
,
**
kwargs
):
device
,
dtype
,
non_blocking
,
convert_to_format
=
torch
.
_C
.
_nn
.
_parse_to
(
*
args
,
**
kwargs
)
if
device
is
not
None
and
device
.
type
==
'cuda'
and
self
.
data
.
device
.
type
==
'cpu'
:
return
self
.
cuda
(
device
)
else
:
new_param
=
Int8Params
(
super
().
to
(
device
=
device
,
dtype
=
dtype
,
non_blocking
=
non_blocking
),
requires_grad
=
self
.
requires_grad
,
has_fp16_weights
=
self
.
has_fp16_weights
)
new_param
.
CB
=
self
.
CB
new_param
.
SCB
=
self
.
SCB
return
new_param
class
Linear8bitLt
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
has_fp16_weights
=
True
,
threshold
=
0.0
,
index
=
None
):
super
(
Linear8bitLt
,
self
).
__init__
(
input_features
,
output_features
,
bias
)
self
.
state
=
bnb
.
MatmulLtState
()
self
.
index
=
index
self
.
state
.
threshold
=
threshold
self
.
state
.
has_fp16_weights
=
has_fp16_weights
if
threshold
>
0.0
and
not
has_fp16_weights
:
self
.
state
.
use_pool
=
True
self
.
weight
=
Int8Params
(
self
.
weight
.
data
,
has_fp16_weights
=
has_fp16_weights
)
def
init_8bit_state
(
self
):
self
.
state
.
CB
=
self
.
weight
.
CB
self
.
state
.
SCB
=
self
.
weight
.
SCB
self
.
weight
.
CB
=
None
self
.
weight
.
SCB
=
None
def
forward
(
self
,
x
):
self
.
state
.
is_training
=
self
.
training
if
self
.
weight
.
CB
is
not
None
:
self
.
init_8bit_state
()
#assert not self.state.has_fp16_weights
#if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
out
=
bnb
.
matmul
(
x
,
self
.
weight
,
state
=
self
.
state
)
if
self
.
bias
is
not
None
:
out
+=
self
.
bias
.
unsqueeze
(
0
).
expand_as
(
out
)
if
not
self
.
state
.
has_fp16_weights
and
self
.
state
.
CB
is
not
None
:
# we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight
del
self
.
state
.
CB
self
.
weight
.
data
=
self
.
state
.
CxB
return
out
class
Linear8bit
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
quant_type
=
'vector'
,
index
=
None
,
args
=
None
,
sparse_decomp
=
False
):
super
(
Linear8bit
,
self
).
__init__
(
input_features
,
output_features
,
bias
)
self
.
quant_type
=
quant_type
self
.
index
=
index
self
.
args
=
args
self
.
iter
=
0
def
forward
(
self
,
x
):
self
.
iter
+=
1
if
self
.
iter
%
self
.
args
.
clip_freq
==
0
:
with
torch
.
no_grad
():
maxval
,
maxidx
=
torch
.
topk
(
torch
.
abs
(
self
.
weight
.
flatten
()),
k
=
self
.
args
.
clip_idx
)
if
not
dist
.
is_initialized
()
or
dist
.
get_rank
()
==
0
:
print
(
'clip'
,
maxval
[
-
1
].
item
())
self
.
weight
.
clip_
(
-
maxval
[
-
1
],
maxval
[
-
1
])
if
self
.
args
is
not
None
:
out
=
bnb
.
nn
.
functional
.
sparse_decomposed_linear8bit
(
x
,
self
.
weight
,
self
.
bias
,
qval
=
self
.
args
.
sparse_decomp_val
,
quant_type
=
self
.
args
.
quant_type
)
else
:
out
=
bnb
.
nn
.
functional
.
linear8bit
(
x
,
self
.
weight
,
self
.
bias
,
quant_type
=
self
.
args
.
quant_type
)
return
out
csrc/kernels.cu
View file @
c771b3a7
...
@@ -1737,10 +1737,884 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
...
@@ -1737,10 +1737,884 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
}
}
}
}
template
<
typename
T
,
int
THREADS
,
int
ITEMS_PER_THREAD
,
int
TILE_ROWS
,
int
TILE_COLS
,
int
SPARSE_DECOMP
>
__global__
void
kgetColRowStats
(
T
*
__restrict__
A
,
float
*
rowStats
,
float
*
colStats
,
int
*
nnz_count_row
,
float
nnz_threshold
,
int
rows
,
int
cols
,
int
tiledRows
,
int
tiledCols
)
{
// 0. reset stats to -FLT_MAX
// 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
// 2. compute col max (per thread); store in smem due to register pressure
// 3. compute row max (per block); store in smem to accumulate full global mem transation
// 4. store data via atomicMax
// each block loads TILE_COLs columns and TILE_ROW rows
// after reading a tile the row counter increase by TILE_ROWS
// the col counter reset after reading TILE_COL elements
const
int
base_row
=
((
blockIdx
.
x
*
TILE_COLS
)
/
tiledCols
)
*
TILE_ROWS
;
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
const
int
base_col
=
(
blockIdx
.
x
*
TILE_COLS
)
%
tiledCols
;
const
int
base_idx
=
(
base_row
*
cols
)
+
base_col
;
const
int
items_per_load
=
ITEMS_PER_THREAD
*
THREADS
;
typedef
cub
::
BlockLoad
<
T
,
THREADS
,
ITEMS_PER_THREAD
,
cub
::
BLOCK_LOAD_VECTORIZE
>
LoadT
;
typedef
cub
::
BlockReduce
<
float
,
THREADS
>
BlockRowReduce
;
typedef
cub
::
BlockReduce
<
int
,
THREADS
>
BlockRowSum
;
typedef
cub
::
BlockExchange
<
float
,
THREADS
,
ITEMS_PER_THREAD
>
BlockExchange
;
__shared__
union
{
typename
BlockExchange
::
TempStorage
exchange
;
typename
BlockRowReduce
::
TempStorage
rowreduce
;
typename
BlockRowSum
::
TempStorage
rowsum
;
typename
LoadT
::
TempStorage
loadt
;
}
temp_storage
;
__shared__
float
smem_row_absmax_values
[
ITEMS_PER_THREAD
*
THREADS
];
__shared__
int
smem_row_nnz_values
[
TILE_ROWS
];
//__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS];
half
local_data
[
ITEMS_PER_THREAD
];
float
local_data_fp32
[
ITEMS_PER_THREAD
];
float
local_col_absmax_values
[
ITEMS_PER_THREAD
];
int
local_row_nnz_count
=
0
;
float
row_absmax
=
-
FLT_MAX
;
// 0. reset stats to -FLT_MAX
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
{
//smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
smem_row_absmax_values
[
threadIdx
.
x
+
(
j
*
THREADS
)]
=
-
FLT_MAX
;
smem_row_nnz_values
[
threadIdx
.
x
+
(
j
*
THREADS
)]
=
0
;
}
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
local_col_absmax_values
[
j
]
=
-
FLT_MAX
;
__syncthreads
();
int
valid_items
=
cols
-
base_col
>
items_per_load
?
items_per_load
:
cols
-
base_col
;
int
i
=
base_idx
;
// we load row after row from the base_position
// 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
for
(
int
row
=
0
;
row
<
TILE_ROWS
;
row
++
)
{
if
(
base_row
+
row
>=
rows
){
break
;
}
local_row_nnz_count
=
0
;
i
=
base_idx
+
((
row
)
*
cols
);
// each thread gets data from the same column
__syncthreads
();
LoadT
(
temp_storage
.
loadt
).
Load
(
&
(
A
[
i
]),
local_data
,
valid_items
,
__float2half
(
0.0
f
));
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
local_data
[
j
]
=
fabsf
(
local_data
[
j
]);
if
(
SPARSE_DECOMP
)
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
{
if
((
float
)
local_data
[
j
]
>=
nnz_threshold
)
{
local_row_nnz_count
+=
1
;
local_data
[
j
]
=
0.0
f
;
}
}
// 2. compute col max (per thread); store in smem due to register pressure
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
// take the col max for this row
// we use shared memory because register pressure is too high if we do this locally
//smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j]));
local_col_absmax_values
[
j
]
=
fmaxf
(
local_col_absmax_values
[
j
],
__half2float
(
local_data
[
j
]));
// 3. compute row max (per block); store in smem to accumulate full global mem transation
__syncthreads
();
// this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
local_data_fp32
[
j
]
=
local_data
[
j
];
row_absmax
=
(
float
)
BlockRowReduce
(
temp_storage
.
rowreduce
).
Reduce
(
local_data_fp32
,
cub
::
Max
());
if
(
SPARSE_DECOMP
)
{
__syncthreads
();
local_row_nnz_count
=
BlockRowSum
(
temp_storage
.
rowsum
).
Sum
(
local_row_nnz_count
);
}
// we store the data temporarily in shared memory so we
// can execute a full atomic block transaction into global memory later
// we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores
if
(
threadIdx
.
x
==
0
)
{
smem_row_absmax_values
[(
row
%
ITEMS_PER_THREAD
)
+
((
row
/
ITEMS_PER_THREAD
)
*
ITEMS_PER_THREAD
)]
=
row_absmax
;
// each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block
smem_row_nnz_values
[
row
]
=
local_row_nnz_count
;
}
__syncthreads
();
}
// 4. store data via atomicMax
// to store col data efficienctly we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0
// into a striped arangement: [0, 8, 16, 24, ..] for t0
__syncthreads
();
BlockExchange
(
temp_storage
.
exchange
).
BlockedToStriped
(
local_col_absmax_values
);
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
if
(
base_col
+
threadIdx
.
x
+
(
j
*
THREADS
)
<
cols
)
{
float
val
=
colStats
[
base_col
+
(
threadIdx
.
x
+
(
j
*
THREADS
))];
if
(
val
<
local_col_absmax_values
[
j
])
atomicMax
(
&
colStats
[
base_col
+
(
threadIdx
.
x
+
(
j
*
THREADS
))],
local_col_absmax_values
[
j
]);
}
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
if
(
base_row
+
threadIdx
.
x
+
(
j
*
THREADS
)
<
rows
)
{
float
val
=
rowStats
[
base_row
+
(
threadIdx
.
x
+
(
j
*
THREADS
))];
if
(
val
<
smem_row_absmax_values
[
threadIdx
.
x
+
(
j
*
THREADS
)])
atomicMax
(
&
rowStats
[
base_row
+
(
threadIdx
.
x
+
(
j
*
THREADS
))],
smem_row_absmax_values
[
threadIdx
.
x
+
(
j
*
THREADS
)]);
}
if
(
SPARSE_DECOMP
)
if
(
threadIdx
.
x
<
TILE_ROWS
)
nnz_count_row
[
blockIdx
.
x
*
TILE_ROWS
+
threadIdx
.
x
+
1
]
=
smem_row_nnz_values
[
threadIdx
.
x
];
}
template
__global__
void
kgetColRowStats
<
half
,
64
,
4
,
16
,
64
*
4
,
0
>(
half
*
__restrict__
A
,
float
*
rowStats
,
float
*
colStats
,
int
*
nnz_count_row
,
float
nnz_threshold
,
int
rows
,
int
cols
,
int
tiledRows
,
int
tiledCols
);
template
__global__
void
kgetColRowStats
<
half
,
64
,
4
,
16
,
64
*
4
,
1
>(
half
*
__restrict__
A
,
float
*
rowStats
,
float
*
colStats
,
int
*
nnz_count_row
,
float
nnz_threshold
,
int
rows
,
int
cols
,
int
tiledRows
,
int
tiledCols
);
#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)
template
<
int
ITEMS_PER_THREAD
,
int
SUBTILE_ROWS
,
int
THREADS
>
__global__
void
kdequant_mm_int32_fp16
(
int
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
half
*
out
,
float
*
newRowStats
,
float
*
newcolStats
,
const
int
numRows
,
const
int
numCols
,
const
int
tileCols
,
const
int
n
)
{
// Strategy: To dequantize we need to load col/row statistics. This can be very expensive
// since different row/col stats need to be loaded with each thread.
// (1, bad algorithm) Loading 32 items per thread would only occur 1 row load, but this increases register pressure
// and would lead to low global load utilization.
// (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do lot of row loads
// for each thread and this is duplicated by a factor of 32/num-cols-per-thread.
// (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock.
// This allows for efficient row/col loading from shared memory within the tile.
// We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has
// the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts
// we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the
// shared memory loads.
// data is in 32 column-tile major with tile width 32 columns and numRows rows
// L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
// L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
// C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register))
// C2. Compute normalization values and store col values in register
// S1. Store C1 into 16-bit output
// S2. Store col/row statistics of new buffer in shared memory
// We allow for sub-tiles to span multiple col32 tiles. This is okay
// since the items per thread only rely on a single column statistic.
const
int
n_out
=
numRows
*
numCols
;
int
num_row_tiles
=
(
numRows
/
SUBTILE_ROWS
)
+
(
numRows
%
SUBTILE_ROWS
==
0
?
0
:
1
);
// we have tiles of size numRows*32, thus col only increases every numRows
// num_row_tiles is the tiles after which the column increases by 32
// blockIdx.x is the index of the current tile
int
col
=
((
threadIdx
.
x
%
32
)
+
((
blockIdx
.
x
/
num_row_tiles
)
*
32
));
// base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached
int
base_row
=
(
blockIdx
.
x
*
SUBTILE_ROWS
)
%
(
num_row_tiles
*
SUBTILE_ROWS
);
// SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS
// subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD
// Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads.
// For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have
// 1024*1024/(128*32) = 256 tiles
// 256 tiles are 256*128*32/4 = 256*1024 threads
// 1. Figure out how index relates to the start of the sub-tile
// 2. Each thread < SUBTILE_ROWS calculates row index
// 3. Load striped and store in shared memory
int
local_values
[
ITEMS_PER_THREAD
];
half
local_output
[
ITEMS_PER_THREAD
];
float
local_rowStats
[
ITEMS_PER_THREAD
];
__shared__
float
smem_rowStats
[
SUBTILE_ROWS
];
typedef
cub
::
BlockLoad
<
int
,
THREADS
,
ITEMS_PER_THREAD
,
cub
::
BLOCK_LOAD_DIRECT
>
LoadInt32
;
typedef
cub
::
BlockExchange
<
int
,
THREADS
,
ITEMS_PER_THREAD
>
ExchangeInt32
;
__shared__
typename
LoadInt32
::
TempStorage
loadint32
;
__shared__
typename
ExchangeInt32
::
TempStorage
exchangeint32
;
// L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
float
colStat
=
col
>=
numCols
?
0.0
f
:
colStats
[
col
];
// no block loads for rows for now -- keep it simple
for
(
int
j
=
threadIdx
.
x
;
j
<
SUBTILE_ROWS
;
j
+=
blockDim
.
x
)
{
// todo: is this global mem access slow due to overlaps or does the L1 cache work well here?
int
row
=
(
base_row
+
j
)
%
numRows
;
// wrap around
// each warp accesses the same element, for four consequitive elements
// todo: update description about striped shared memory, it is not needed
// rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements
smem_rowStats
[
j
]
=
rowStats
[
row
];
}
__syncthreads
();
// each block processes SUBTILE_ROWS*32 elements
const
int
items_per_load
=
THREADS
*
ITEMS_PER_THREAD
;
const
int
rows_per_load
=
items_per_load
/
32
;
int
subtile_base_row
=
(
threadIdx
.
x
/
32
)
*
ITEMS_PER_THREAD
;
// row within the tile
int
row_offset
=
0
;
// subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed
int
subtile_start
=
(
blockIdx
.
x
/
num_row_tiles
)
*
(
numRows
*
32
)
+
(
base_row
*
32
);
for
(
int
subtile_idx
=
subtile_start
;
subtile_idx
<
subtile_start
+
(
SUBTILE_ROWS
*
32
);
subtile_idx
+=
items_per_load
)
{
int
valid_rows
=
numRows
-
(
base_row
+
row_offset
)
>
rows_per_load
?
rows_per_load
:
numRows
-
(
base_row
+
row_offset
);
int
valid_items
=
valid_rows
*
32
;
if
(
valid_items
<=
0
)
// the sub-tile might have more elements than the tile itself
break
;
// L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
LoadInt32
(
loadint32
).
Load
(
&
(
A
[
subtile_idx
]),
local_values
,
valid_items
,
0
);
ExchangeInt32
(
exchangeint32
).
BlockedToWarpStriped
(
local_values
,
local_values
);
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
local_rowStats
[
j
]
=
smem_rowStats
[
subtile_base_row
+
row_offset
+
j
];
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
local_output
[
j
]
=
__float2half
(
local_values
[
j
]
*
MM_DEQUANT_CONST
*
local_rowStats
[
j
]
*
colStat
);
//absmax_col = fmax(fabsf(local_output[j]), absmax_col);
// we store data in row major
// to store data efficiently, we want to use block exchange: [0, 32, 64, 92] -> [0, 1, 2, 3]
// so that each thread holds ITEMS_PER_THREAD consecutive items for each row
// this way throughput into storage is increased by a factor of ~2x
// for now we use a simple store
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
{
int
outIdx
=
col
+
((
base_row
+
subtile_base_row
+
row_offset
+
j
)
*
numCols
);
if
(
outIdx
<
n_out
&&
col
<
numCols
)
out
[
outIdx
]
=
local_output
[
j
];
}
row_offset
+=
rows_per_load
;
}
}
template
<
int
THREADS
,
int
ITEMS_PER_THREAD
,
int
TILE_ROWS
,
int
TILE_COLS
,
int
SPARSE_DECOMP
>
__global__
void
kDoubleRowColQuant
(
half
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
char
*
out_col_normed
,
char
*
out_row_normed
,
int
*
rowidx
,
int
*
colidx
,
half
*
val
,
int
*
__restrict__
nnz_block_ptr
,
float
threshold
,
int
rows
,
int
cols
,
int
tiledCols
)
{
// assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD
// Each thread reads the same column but multiple rows
// Rows are loaded in shared memory and access is shared across the threadblock (broadcast)
// 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
// 1. Load data row by row (should be at least with TILE_SIZE = 512)
// 2. quantize data with row/col stats
// 3. Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance)
// each block loads TILE_COLs columns and TILE_ROW rows
// after reading a tile the row counter increase by TILE_ROWS
// the col counter reset after reading TILE_COL elements
const
int
base_row
=
((
blockIdx
.
x
*
TILE_COLS
)
/
tiledCols
)
*
TILE_ROWS
;
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
const
int
base_col
=
(
blockIdx
.
x
*
TILE_COLS
)
%
tiledCols
;
const
int
base_idx
=
(
base_row
*
cols
)
+
base_col
;
const
int
items_per_load
=
ITEMS_PER_THREAD
*
THREADS
;
typedef
cub
::
BlockLoad
<
half
,
THREADS
,
ITEMS_PER_THREAD
,
cub
::
BLOCK_LOAD_VECTORIZE
>
LoadHalf
;
__shared__
typename
LoadHalf
::
TempStorage
loadhalf
;
typedef
cub
::
BlockStore
<
char
,
THREADS
,
ITEMS_PER_THREAD
,
cub
::
BLOCK_STORE_VECTORIZE
>
StoreInt8
;
__shared__
typename
StoreInt8
::
TempStorage
storeint8
;
__shared__
float
smem_row_stats
[
TILE_ROWS
];
__shared__
unsigned
int
smem_nnz_row_idx
[
TILE_ROWS
];
half
local_data
[
ITEMS_PER_THREAD
];
float
local_col_stats
[
ITEMS_PER_THREAD
];
char
local_quantized_data
[
ITEMS_PER_THREAD
];
// 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
if
(
base_col
+
(
threadIdx
.
x
*
ITEMS_PER_THREAD
)
+
j
<
cols
)
local_col_stats
[
j
]
=
__fdividef
(
127.0
f
,
colStats
[
base_col
+
(
threadIdx
.
x
*
ITEMS_PER_THREAD
)
+
j
]);
for
(
int
i
=
threadIdx
.
x
;
i
<
TILE_ROWS
;
i
+=
blockDim
.
x
)
{
if
(
base_row
+
i
<
rows
)
smem_row_stats
[
i
]
=
rowStats
[
base_row
+
i
];
if
(
SPARSE_DECOMP
)
smem_nnz_row_idx
[
i
]
=
nnz_block_ptr
[(
TILE_ROWS
*
blockIdx
.
x
)
+
i
];
}
__syncthreads
();
// we load row after row from the base_position
// 1. Load data row by row (should be at least with TILE_SIZE = 512)
for
(
int
row
=
0
;
row
<
TILE_ROWS
;
row
++
)
{
if
(
base_row
+
row
>=
rows
){
break
;
}
int
i
=
base_idx
+
(
row
*
cols
);
int
valid_items
=
cols
-
base_col
>
items_per_load
?
items_per_load
:
cols
-
base_col
;
LoadHalf
(
loadhalf
).
Load
(
&
(
A
[
i
]),
local_data
,
valid_items
,
0.0
f
);
float
row_stat
=
__fdividef
(
127.0
f
,
smem_row_stats
[
row
]);
// 2. quantize data with row/col stats
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
{
// we already pre-normalized the col/row stat:
// what this does is float/absmax*127 = int8
if
(
SPARSE_DECOMP
)
{
if
(
fabsf
((
float
)
local_data
[
j
])
>=
threshold
)
{
local_quantized_data
[
j
]
=
0
;
int
old_idx
=
atomicInc
(
&
smem_nnz_row_idx
[
row
],
UINT_MAX
);
rowidx
[
old_idx
]
=
base_row
+
row
;
colidx
[
old_idx
]
=
base_col
+
(
threadIdx
.
x
*
ITEMS_PER_THREAD
)
+
j
;
val
[
old_idx
]
=
local_data
[
j
];
}
else
{
local_quantized_data
[
j
]
=
(
char
)(
rintf
(
__half2float
(
local_data
[
j
])
*
row_stat
));
}
}
else
local_quantized_data
[
j
]
=
(
char
)(
rintf
(
__half2float
(
local_data
[
j
])
*
row_stat
));
}
StoreInt8
(
storeint8
).
Store
(
&
(
out_row_normed
[
i
]),
local_quantized_data
,
valid_items
);
// 2. quantize data with row/col stats
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
{
// we already pre-normalized the col/row stat:
// what this does is float/absmax*127 = int8
local_quantized_data
[
j
]
=
(
char
)(
rintf
(
__half2float
(
local_data
[
j
])
*
local_col_stats
[
j
]));
}
__syncthreads
();
StoreInt8
(
storeint8
).
Store
(
&
(
out_col_normed
[
i
]),
local_quantized_data
,
valid_items
);
}
}
template
<
int
THREADS
,
int
ITEMS_PER_THREAD
,
int
TILE_ROWS
,
int
TILE_COLS
,
int
TRANSPOSE
,
int
FORMAT
>
__global__
void
kTransformRowToFormat
(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
)
{
// 0. Load data into 32*32 shared memory tiles
// 1. transpose / reorder in shared memory
// 2. store
// COL32 FORMAT:
// rows*32 tiles
// TURING FORMAT:
// 8*32 tiles with 4*4 subtiles
// the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
// the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
// the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
// the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
// index increases by 32
// AMPERE FORMAT:
// 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
// row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
// To have efficient loads and stores if we transpose we need 128 consequitive bytes which at 1 byte are 128 values
// As such we need:
// at least 32*4 shared memory tiles for col32; preferably 32*32
// at least 32*6 shared memory tiles for col32_ampere: preferably 32*32
// at least 32*8 shared memory tiles for col4_turing: preferably 32*32
// for efficient loading of row major we need to load 128 elements and repeat this 32 items
// this would imply a 32x128 shared memory tile -> 4kb
// It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb
// we have 64k sharded mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy
// for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough
// register pressure should be low with: 8 registers from local memoryh per block and 64 registers per SM
//
// to make the shared memory work with that occupancy we might need to union the block loads/stores
// each block loads TILE_COLs columns and TILE_ROW rows
// after reading a tile the row counter increase by TILE_ROWS
// the col counter reset after reading TILE_COL elements
const
int
base_row
=
((
blockIdx
.
x
*
TILE_COLS
)
/
tiledCols
)
*
TILE_ROWS
;
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
const
int
base_col
=
(
blockIdx
.
x
*
TILE_COLS
)
%
tiledCols
;
const
int
base_idx
=
(
base_row
*
cols
)
+
base_col
;
// we load 128 bytes per warp with
// 32 rows for transposes that fill col32 types
// so that we can have contiguous stores
__shared__
char
smem_data
[
32
*
33
*
ITEMS_PER_THREAD
];
char
local_data
[
ITEMS_PER_THREAD
];
typedef
cub
::
BlockExchange
<
char
,
THREADS
,
ITEMS_PER_THREAD
>
BlockExchange
;
__shared__
typename
BlockExchange
::
TempStorage
temp_storage
;
// we load row after row from the base_position
// Load data row by row
int
warps
=
blockDim
.
x
/
32
;
int
warp_id
=
threadIdx
.
x
/
32
;
int
warp_lane
=
threadIdx
.
x
%
32
;
int
offset
=
0
;
int
smem_row
=
0
;
// each warp loads one row of 128 bytes
for
(
int
row
=
warp_id
;
row
<
TILE_ROWS
;
row
+=
warps
)
{
int
i
=
base_idx
+
(
row
*
cols
);
// we load up to 128 bytes/items per load
int
valid_items
=
cols
-
base_col
>
32
*
ITEMS_PER_THREAD
?
32
*
ITEMS_PER_THREAD
:
cols
-
base_col
;
// 0. Load data into 32*32 shared memory tiles
if
(
base_row
+
row
<
rows
)
{
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
{
int
col_idx
=
warp_lane
+
(
j
*
32
);
if
(
col_idx
<
valid_items
)
local_data
[
j
]
=
A
[
i
+
col_idx
];
else
local_data
[
j
]
=
0
;
}
}
else
{
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
local_data
[
j
]
=
0
;
}
if
(
TRANSPOSE
)
{
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
{
int
local_col
=
(
32
*
j
)
+
warp_lane
;
//int local_row = row;
// store as 256x32
smem_data
[(
local_col
*
33
)
+
row
]
=
local_data
[
j
];
}
}
else
{
// treat smem as 32x256, that is 32 rows and 256 columns
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
smem_data
[
row
*
32
*
ITEMS_PER_THREAD
+
(
warp_lane
)
+
(
j
*
32
)]
=
local_data
[
j
];
}
smem_row
+=
warps
;
// 1. transpose / reorder in shared memory
if
(
smem_row
%
32
==
0
)
{
smem_row
=
0
;
__syncthreads
();
for
(
int
subrow
=
warp_id
;
subrow
<
32
;
subrow
+=
warps
)
{
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
{
switch
(
FORMAT
)
{
case
COL32
:
if
(
TRANSPOSE
)
{
// data lies in shared memory in the following way:
// row0 [col0 col1 ... col31]
// row1 [col0 col1 ... col31]
// ...
//
// As such we read consequtive entries with 256 threads (8rows x 32 columns)
// as j increase, the row increase by a factor of 8
// We load 8 rows per subrow loop, and subrow increase by 8 per loop
// so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8
const
int
jrow
=
j
*
ITEMS_PER_THREAD
;
// 8 rows per j
const
int
subrow_loop_row
=
(
subrow
/
warps
)
*
ITEMS_PER_THREAD
*
ITEMS_PER_THREAD
;
// 8 rows per j; 8j per subrow loop (subrow/warps)
//const int local_row = warp_id; // each warp_id is one row
//const int block_row = base_col; // block offset for row
//const int local_col = warp_lane
//const int global_col = base_row; // block offset for col
if
((
base_col
+
subrow_loop_row
+
jrow
+
warp_id
<
outRows
)
&&
(
base_row
+
warp_lane
<
rows
))
{
// each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem
char
data
=
smem_data
[(
subrow_loop_row
+
jrow
+
warp_id
)
*
33
+
warp_lane
];
// each 32 columns we have new tile
// each tile has size outRows*32 and base_row is done in increments of 32
offset
=
base_row
*
outRows
;
out
[
offset
+
(
base_col
+
jrow
+
subrow_loop_row
)
*
32
+
threadIdx
.
x
]
=
data
;
}
}
else
{
if
(((
base_row
+
subrow
)
<
rows
)
&&
(
base_col
+
(
j
*
32
)
+
warp_lane
<
outCols
))
{
offset
=
(
base_col
/
32
)
*
(
32
*
rows
);
char
data
=
smem_data
[(
subrow
*
32
*
ITEMS_PER_THREAD
)
+
(
j
*
32
)
+
warp_lane
];
out
[
offset
+
(
base_row
+
subrow
)
*
32
+
((
j
)
*
rows
*
32
)
+
warp_lane
]
=
data
;
}
}
break
;
case
COL_TURING
:
// TURING FORMAT:
// 8*32 tiles with 4*4 subtiles
// the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
// the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
// the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
// the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
// index increases by 32
//
// [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
if
(
TRANSPOSE
)
{
const
int
jrow
=
j
*
ITEMS_PER_THREAD
;
// 8 rows per j
const
int
subrow_loop_row
=
(
subrow
/
warps
)
*
ITEMS_PER_THREAD
*
ITEMS_PER_THREAD
;
// 8 rows per j; 8j per subrow loop (subrow/warps)
//const int local_row = warp_id; // each warp_id is one row
//const int block_row = base_col; // block offset for row
//const int local_col = warp_lane
//const int global_col = base_row; // block offset for col
if
((
base_col
+
subrow_loop_row
+
jrow
+
warp_id
<
outRows
)
&&
(
base_row
+
warp_lane
<
rows
))
{
// each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem
char
data
=
smem_data
[(
subrow_loop_row
+
jrow
+
warp_id
)
*
33
+
warp_lane
];
// each 32 columns we have new tile
// each tile has size 8*32 = 256 elements offset
// for each row offset of 8 we increaes the tile first
// after all rows are exhausted, we increase the col
int
row_offset
=
((
base_col
+
jrow
+
subrow_loop_row
+
warp_id
)
/
8
)
*
256
;
// global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
// we increase by row_tile_column every 32 columns
// base_row increase in increments of 32
//int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements
//int col_offset = (base_row/32)*row_tile_column;
// -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
// 256*outRows/8*base_row/32 = outRows*base_row
int
col_offset
=
outRows
*
base_row
;
offset
=
row_offset
+
col_offset
;
// since we process even number of rows with each j (8) and with each subrow (8j) we can determine
// odd or even rows with the warp_id (each warp processes one row)
// the col is warp_lane (max 32 columns per row) and the row warp_id
if
(
warp_id
%
2
==
1
)
// odd
offset
+=
128
+
(
warp_lane
/
4
)
*
16
+
(
warp_lane
%
4
)
+
(((
warp_id
%
8
)
-
1
)
*
2
);
else
// even
offset
+=
0
+
(
warp_lane
/
4
)
*
16
+
(
warp_lane
%
4
)
+
((
warp_id
%
8
)
*
2
);
out
[
offset
]
=
data
;
}
}
else
{
if
(((
base_row
+
subrow
)
<
rows
)
&&
(
base_col
+
(
j
*
32
)
+
warp_lane
<
outCols
))
{
char
data
=
smem_data
[(
subrow
*
32
*
ITEMS_PER_THREAD
)
+
(
j
*
32
)
+
warp_lane
];
// set offset designates the tile offset among the 8*32 tiles
// we first increase rows and then columns. Since we load 128 columns at once
// we increase the offset by outRows*32 every 32 columns
// additionally, we increase the offset by 8*32=256 every 8 rows
offset
=
((
base_col
+
(
j
*
32
))
/
32
)
*
outRows
*
32
+
(((
base_row
+
subrow
)
/
8
)
*
256
);
// global offset (8x32 tile)
// first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd
// each of these has 32 values in total for 32*4 = 128 as offset if odd
// every set of 4 columns increases the total offset by 16
// each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2
// this happends every 8 rows anew (subrow % 8)
// one writes 4 columns at once that is (col % 4) for the particular index in the subtile
int
subcol
=
warp_lane
;
// add local offset (4x4 sub-tile)
if
(
subrow
%
2
==
1
)
// odd
offset
+=
128
+
(
subcol
/
4
)
*
16
+
(
subcol
%
4
)
+
(((
subrow
%
8
)
-
1
)
*
2
);
else
// even
offset
+=
0
+
(
subcol
/
4
)
*
16
+
(
subcol
%
4
)
+
((
subrow
%
8
)
*
2
);
out
[
offset
]
=
data
;
}
}
break
;
case
COL_AMPERE
:
// AMPERE FORMAT:
// 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
// row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
if
(
TRANSPOSE
)
{
const
int
jrow
=
j
*
ITEMS_PER_THREAD
;
// 8 rows per j
const
int
subrow_loop_row
=
(
subrow
/
warps
)
*
ITEMS_PER_THREAD
*
ITEMS_PER_THREAD
;
// 8 rows per j; 8j per subrow loop (subrow/warps)
//const int local_row = warp_id; // each warp_id is one row
//const int block_row = base_col; // block offset for row
//const int local_col = warp_lane
//const int global_col = base_row; // block offset for col
if
((
base_col
+
subrow_loop_row
+
jrow
+
warp_id
<
outRows
)
&&
(
base_row
+
warp_lane
<
rows
))
{
// each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem
char
data
=
smem_data
[(
subrow_loop_row
+
jrow
+
warp_id
)
*
33
+
warp_lane
];
// each 32 columns we have new tile
// each tile has size 32*32 = 1024 elements offset
// for each row offset of 32 we increaes the tile first
// after all rows are exhausted, we increase the col
int
row_offset
=
((
base_col
+
jrow
+
subrow_loop_row
+
warp_id
)
/
32
)
*
1024
;
// global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
// we increase by row_tile_column every 32 columns
// base_row increase in increments of 32
//int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements
//int col_offset = (base_row/32)*row_tile_column;
// -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
// 1024*outRows/32*base_row/32 = outRows*base_row
int
col_offset
=
outRows
*
base_row
;
offset
=
row_offset
+
col_offset
;
// same as in the non-transpose case (see below)
// the difference is that now rows = cols
// in this case warp_id = subrow
// [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
// subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
// every 2 rows, the offset increases by two [0, 1, 8, 9...]
// every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
int
local_row
=
(
jrow
+
warp_id
)
%
32
;
// offset for row > 32 is already calculated into row_offset
int
ampere_row
=
((
local_row
%
8
)
/
2
)
*
8
+
(
local_row
/
8
)
*
2
+
(
local_row
%
2
);
// global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane
out
[
offset
+
(
ampere_row
*
32
)
+
warp_lane
]
=
data
;
}
}
else
{
if
(((
base_row
+
subrow
)
<
rows
)
&&
(
base_col
+
(
j
*
32
)
+
warp_lane
<
outCols
))
{
char
data
=
smem_data
[(
subrow
*
32
*
ITEMS_PER_THREAD
)
+
(
j
*
32
)
+
warp_lane
];
// set offset designates the tile offset among the 32*32 tiles
// we first increase rows and then columns. Since we load 128 columns at once
// we increase the offset by outRows*32 every 32 columns
// additionally, we increase the offset by 32*32=1024 every 32 rows
offset
=
((
base_col
+
(
j
*
32
))
/
32
)
*
outRows
*
32
+
(((
base_row
+
subrow
)
/
32
)
*
1024
);
// global offset (32x32 tile)
// [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
// subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
// every 2 rows, the offset increases by two [0, 1, 8, 9...]
// every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
int
local_row
=
((
subrow
%
8
)
/
2
)
*
8
+
(
subrow
/
8
)
*
2
+
(
subrow
%
2
);
// global offset + row with 32 cols each + 32 cols per j + col_idx
out
[
offset
+
(
local_row
*
32
)
+
warp_lane
]
=
data
;
}
}
break
;
}
}
}
}
}
}
#define C 1.0f/127.0f
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
template
<
typename
T
,
int
SPMM_ITEMS
,
int
BITS
>
__global__
void
kspmm_coo_very_sparse_naive
(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
T
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
)
{
// 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block
// If a block finishes, the next one is scheduled. Since the last blocks like have fewer
// elements they finish faster "fillin up" the gaps left by larger blocks
// without tensor cores
// 1. use rowidx_length to find what to load (as many blocks as there are rows)
// 2. Load A into registers
// 3. each warp loads all required rows of B but each warp is offset by k
// 4. Do mma operations that accumulate into registers
// 5. Each warp stores its output row into matrix C
const
int
count
=
max_count
[
blockIdx
.
x
];
const
int
local_max_idx
=
max_idx
[
blockIdx
.
x
];
const
int
offset
=
local_max_idx
==
0
?
0
:
offset_rowidx
[
local_max_idx
-
1
];
const
int
local_row_idx
=
rowidx
[
offset
];
const
int
warp_id
=
threadIdx
.
x
/
32
;
const
int
warp_idx
=
threadIdx
.
x
%
32
;
const
int
warp_offset
=
(
warp_id
*
32
)
*
SPMM_ITEMS
;
const
int
num_items
=
BITS
==
8
?
8
:
8
;
int
idx_col_B
=
warp_offset
;
int
local_idx_col_B_offset
=
0
;
half
local_valA
[
MAX_SPARSE_COUNT
];
int
local_colidxA
[
MAX_SPARSE_COUNT
];
half
local_valC
[
SPMM_ITEMS
];
T
local_valsB
[
num_items
];
half
local_valOut
[
num_items
];
// 128 byte loads per warp == 4 bytes per thread
// 2. Load A into registers
for
(
int
j
=
0
;
j
<
MAX_SPARSE_COUNT
;
j
++
)
{
local_valA
[
j
]
=
j
<
count
?
values
[
offset
+
j
]
:
__float2half
(
0.0
f
);
local_colidxA
[
j
]
=
j
<
count
?
colidx
[
offset
+
j
]
:
0
;
}
// each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192
// we expect each warp to be SPMM_ITEMS*32 apart
// we have a total of 128 bytes for the bank with a bank size of 4 bytes
// added 3 bytes = 6 values between warps should reduce bank conflicts
__shared__
half
smem_dequant_stats
[
SMEM_SIZE
];
while
(
idx_col_B
<
colsB
)
{
if
(
dequant_stats
!=
NULL
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
SMEM_SIZE
;
i
+=
blockDim
.
x
)
if
((
idx_col_B
+
i
-
local_idx_col_B_offset
)
<
colsB
)
smem_dequant_stats
[
i
]
=
__ldg
(
&
dequant_stats
[
idx_col_B
+
i
-
local_idx_col_B_offset
]);
__syncthreads
();
}
#pragma unroll SPMM_ITEMS
for
(
int
j
=
0
;
j
<
SPMM_ITEMS
;
j
++
)
local_valC
[
j
]
=
0.0
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
count
;
i
++
)
{
// 3. each warp loads all required rows of B but each warp is offset by k
int
row_offset
=
colsB
*
local_colidxA
[
i
];
#pragma unroll SPMM_ITEMS
for
(
int
j
=
0
;
j
<
SPMM_ITEMS
;
j
+=
num_items
)
{
// 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached
int
idx
=
idx_col_B
+
(
warp_idx
*
SPMM_ITEMS
)
+
j
;
if
(
idx
>=
colsB
){
break
;
}
//printf("%i %i\n", (row_offset+idx) % num_items, row_offset+idx);
if
((
idx
+
num_items
<
colsB
))
{
if
(
BITS
==
8
)
reinterpret_cast
<
float2
(
&
)[
num_items
]
>
(
local_valsB
)[
0
]
=
reinterpret_cast
<
float2
*>
(
B
)[(
row_offset
+
idx
)
/
num_items
];
else
reinterpret_cast
<
float4
(
&
)[
num_items
]
>
(
local_valsB
)[
0
]
=
reinterpret_cast
<
float4
*>
(
B
)[(
row_offset
+
idx
)
/
num_items
];
}
else
{
#pragma unroll num_items
for
(
int
k
=
0
;
k
<
num_items
;
k
++
)
if
(
idx
+
k
<
colsB
)
local_valsB
[
k
]
=
B
[
row_offset
+
idx
+
k
];
else
local_valsB
[
k
]
=
0.0
f
;
}
#pragma unroll num_items
for
(
int
k
=
0
;
k
<
num_items
;
k
++
)
{
//if((float)local_valsB[k] != 0.0)
// printf("%f %i %i %i\n", (float)local_valsB[k], k, idx, colsB);
if
(
BITS
==
8
&&
dequant_stats
!=
NULL
)
// we do texture cache reads (__ldg) on dequant_stats which should be super fast
{
float
valB
=
local_valsB
[
k
];
float
valA
=
local_valA
[
i
];
if
(
valB
!=
0.0
&&
valA
!=
0.0
)
local_valC
[
j
+
k
]
=
(
float
)
local_valC
[
j
+
k
]
+
((
float
)
smem_dequant_stats
[
idx
+
k
-
local_idx_col_B_offset
])
*
C
*
valB
*
valA
;
}
else
local_valC
[
j
+
k
]
=
(
float
)
local_valC
[
j
+
k
]
+
(
float
)
local_valsB
[
k
]
*
(
float
)
local_valA
[
i
];
}
}
}
int
idx_row_C
=
(
colsB
*
local_row_idx
);
#pragma unroll SPMM_ITEMS
for
(
int
j
=
0
;
j
<
SPMM_ITEMS
;
j
+=
num_items
)
{
//int idx_col_C = idx_col_B + (32*j) + warp_idx;
int
idx_col_C
=
idx_col_B
+
warp_idx
*
SPMM_ITEMS
+
j
;
int
idx_val
=
idx_col_C
+
idx_row_C
;
if
(
idx_col_C
+
num_items
<
colsB
)
{
// load outputs to do inplace addition
reinterpret_cast
<
float4
(
&
)[
num_items
/
4
]
>
(
local_valOut
)[
0
]
=
reinterpret_cast
<
float4
*>
(
out
)[
idx_val
/
num_items
];
#pragma unroll num_items
for
(
int
k
=
0
;
k
<
num_items
;
k
++
)
local_valC
[(
j
/
num_items
)
+
k
]
=
(
float
)
local_valC
[(
j
/
num_items
)
+
k
]
+
(
float
)
local_valOut
[
k
];
reinterpret_cast
<
float4
*>
(
out
)[
idx_val
/
num_items
]
=
reinterpret_cast
<
float4
(
&
)[
num_items
]
>
(
local_valC
)[
j
/
num_items
];
}
else
{
#pragma unroll num_items
for
(
int
k
=
0
;
k
<
num_items
;
k
++
)
if
(
idx_col_C
+
k
<
colsB
)
out
[
idx_val
+
k
]
=
(
float
)
out
[
idx_val
+
k
]
+
(
float
)
local_valC
[
j
+
k
];
}
}
idx_col_B
+=
blockDim
.
x
*
SPMM_ITEMS
;
local_idx_col_B_offset
+=
blockDim
.
x
*
SPMM_ITEMS
;
}
}
//==============================================================
//==============================================================
// TEMPLATE DEFINITIONS
// TEMPLATE DEFINITIONS
//==============================================================
//==============================================================
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
8
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
16
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
32
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
signed
char
,
8
,
8
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
signed
char
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
signed
char
,
16
,
8
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
signed
char
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
signed
char
,
32
,
8
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
signed
char
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
0
,
COL32
>(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
__global__
void
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
1
,
COL32
>(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
__global__
void
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
0
,
COL_TURING
>(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
__global__
void
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
1
,
COL_TURING
>(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
__global__
void
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
0
,
COL_AMPERE
>(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
__global__
void
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
1
,
COL_AMPERE
>(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
__global__
void
kdequant_mm_int32_fp16
<
4
,
128
,
512
>(
int
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
half
*
out
,
float
*
newRowStats
,
float
*
newcolStats
,
const
int
numRows
,
const
int
numCols
,
const
int
tileCols
,
const
int
n
);
template
__global__
void
kDoubleRowColQuant
<
64
,
4
,
16
,
64
*
4
,
0
>(
half
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
char
*
out_col_normed
,
char
*
out_row_normed
,
int
*
rowidx
,
int
*
colidx
,
half
*
val
,
int
*
__restrict__
nnz_block_ptr
,
float
threshold
,
int
rows
,
int
cols
,
int
tiledCols
);
template
__global__
void
kDoubleRowColQuant
<
64
,
4
,
16
,
64
*
4
,
1
>(
half
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
char
*
out_col_normed
,
char
*
out_row_normed
,
int
*
rowidx
,
int
*
colidx
,
half
*
val
,
int
*
__restrict__
nnz_block_ptr
,
float
threshold
,
int
rows
,
int
cols
,
int
tiledCols
);
template
__device__
unsigned
char
dQuantize
<
0
>(
float
*
smem_code
,
const
float
rand
,
float
x
);
template
__device__
unsigned
char
dQuantize
<
0
>(
float
*
smem_code
,
const
float
rand
,
float
x
);
template
__device__
unsigned
char
dQuantize
<
1
>(
float
*
smem_code
,
const
float
rand
,
float
x
);
template
__device__
unsigned
char
dQuantize
<
1
>(
float
*
smem_code
,
const
float
rand
,
float
x
);
...
...
csrc/kernels.cuh
View file @
c771b3a7
...
@@ -106,6 +106,18 @@ template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileCl
...
@@ -106,6 +106,18 @@ template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileCl
__global__
void
kHistogramScatterAdd2D
(
float
*
histogram
,
int
*
index1
,
int
*
index2
,
float
*
src
,
const
int
maxidx1
,
const
int
n
);
__global__
void
kHistogramScatterAdd2D
(
float
*
histogram
,
int
*
index1
,
int
*
index2
,
float
*
src
,
const
int
maxidx1
,
const
int
n
);
template
<
typename
T
,
int
SPMM_ITEMS
,
int
BITS
>
__global__
void
kspmm_coo_very_sparse_naive
(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
T
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
<
int
ITEMS_PER_THREAD
,
int
SUBTILE_ROWS
,
int
THREADS
>
__global__
void
kdequant_mm_int32_fp16
(
int
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
half
*
out
,
float
*
newRowStats
,
float
*
newcolStats
,
const
int
numRows
,
const
int
numCols
,
const
int
tileCols
,
const
int
n
);
template
<
typename
T
,
int
THREADS
,
int
ITEMS_PER_THREAD
,
int
TILE_ROWS
,
int
TILE_COLS
,
int
SPARSE_DECOMP
>
__global__
void
kgetColRowStats
(
T
*
__restrict__
A
,
float
*
rowStats
,
float
*
colStats
,
int
*
nnz_count_row
,
float
nnz_threshold
,
int
rows
,
int
cols
,
int
tiledRows
,
int
tiledCols
);
template
<
int
THREADS
,
int
ITEMS_PER_THREAD
,
int
TILE_ROWS
,
int
TILE_COLS
,
int
SPARSE_DECOMP
>
__global__
void
kDoubleRowColQuant
(
half
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
char
*
out_col_normed
,
char
*
out_row_normed
,
int
*
rowidx
,
int
*
colidx
,
half
*
val
,
int
*
__restrict__
nnz_block_ptr
,
float
threshold
,
int
rows
,
int
cols
,
int
tiledCols
);
template
<
int
THREADS
,
int
ITEMS_PER_THREAD
,
int
TILE_ROWS
,
int
TILE_COLS
,
int
TRANSPOSE
,
int
FORMAT
>
__global__
void
kTransformRowToFormat
(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
#endif
#endif
csrc/ops.cu
View file @
c771b3a7
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <cub/device/device_scan.cuh>
#include <cub/device/device_scan.cuh>
#include <limits>
#include <limits>
#include <BinSearch.h>
#include <BinSearch.h>
#include <cassert>
#include <common.h>
#include <common.h>
...
@@ -188,11 +189,416 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
...
@@ -188,11 +189,416 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
}
}
void
gemmex
(
Context
*
context
,
bool
transposeA
,
bool
transposeB
,
int
m
,
int
n
,
int
k
,
void
*
A
,
void
*
B
,
void
*
C
,
int
lda
,
int
ldb
,
int
ldc
)
{
const
int
falpha
=
1
;
const
int
fbeta
=
0
;
const
void
*
alpha
=
&
falpha
;
const
void
*
beta
=
&
fbeta
;
cublasStatus_t
status
;
status
=
cublasGemmEx
(
context
->
m_handle
,
transposeA
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
transposeB
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
m
,
n
,
k
,
alpha
,
A
,
CUDA_R_8I
,
lda
,
B
,
CUDA_R_8I
,
ldb
,
beta
,
C
,
CUDA_R_32I
,
ldc
,
CUDA_R_32I
,
CUBLAS_GEMM_DEFAULT_TENSOR_OP
);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
std
::
cout
<<
"CUBLAS ERROR: Status "
<<
status
<<
std
::
endl
;
}
}
void
strided_gemmex
(
Context
*
context
,
bool
transposeA
,
bool
transposeB
,
int
m
,
int
n
,
int
k
,
void
*
A
,
void
*
B
,
void
*
C
,
int
lda
,
int
ldb
,
int
ldc
,
long
long
int
strideA
,
long
long
int
strideB
,
long
long
int
strideC
,
int
batchCount
)
{
const
int
falpha
=
1
;
const
int
fbeta
=
0
;
const
void
*
alpha
=
&
falpha
;
const
void
*
beta
=
&
fbeta
;
cublasStatus_t
status
;
//cout << transposeA << transposeB << endl;
//printf("%i %i %i\n", m,n,k);
//printf("%i %i %i\n", lda,ldb,ldc);
//printf("%i %i %i\n", strideA, strideB, strideC);
//printf("%i\n", batchCount);
status
=
cublasGemmStridedBatchedEx
(
context
->
m_handle
,
transposeA
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
transposeB
?
CUBLAS_OP_T
:
CUBLAS_OP_N
,
m
,
n
,
k
,
alpha
,
A
,
CUDA_R_8I
,
lda
,
(
long
long
int
)
strideA
,
B
,
CUDA_R_8I
,
ldb
,
(
long
long
int
)
strideB
,
beta
,
C
,
CUDA_R_32I
,
ldc
,
(
long
long
int
)
strideC
,
batchCount
,
CUDA_R_32I
,
CUBLAS_GEMM_DEFAULT
);
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
{
std
::
cout
<<
"CUBLAS ERROR: Status "
<<
status
<<
std
::
endl
;
}
}
int
roundoff
(
int
v
,
int
d
)
{
return
(
v
+
d
-
1
)
/
d
*
d
;
}
template
<
int
ORDER
>
cublasLtOrder_t
get_order
()
{
switch
(
ORDER
)
{
case
ROW
:
return
CUBLASLT_ORDER_ROW
;
break
;
case
COL
:
return
CUBLASLT_ORDER_COL
;
break
;
case
COL32
:
return
CUBLASLT_ORDER_COL32
;
break
;
case
COL_TURING
:
return
CUBLASLT_ORDER_COL4_4R2_8C
;
break
;
case
COL_AMPERE
:
return
CUBLASLT_ORDER_COL32_2R_4R4
;
break
;
}
}
template
cublasLtOrder_t
get_order
<
ROW
>();
template
cublasLtOrder_t
get_order
<
COL
>();
template
cublasLtOrder_t
get_order
<
COL32
>();
template
cublasLtOrder_t
get_order
<
COL_TURING
>();
template
cublasLtOrder_t
get_order
<
COL_AMPERE
>();
template
<
int
ORDER
>
int
get_leading_dim
(
int
dim1
,
int
dim2
)
{
switch
(
ORDER
)
{
case
ROW
:
return
dim2
;
break
;
case
COL
:
return
dim1
;
break
;
case
COL32
:
// 32*row tiles
return
dim1
*
32
;
break
;
case
COL_TURING
:
return
32
*
roundoff
(
dim1
,
8
);
break
;
case
COL_AMPERE
:
// 32*32 tiles
return
32
*
roundoff
(
dim1
,
32
);
break
;
}
}
template
int
get_leading_dim
<
ROW
>(
int
dim1
,
int
dim2
);
template
int
get_leading_dim
<
COL
>(
int
dim1
,
int
dim2
);
template
int
get_leading_dim
<
COL32
>(
int
dim1
,
int
dim2
);
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
{
  cublasLtOrder_t orderA = get_order<SRC>();
  cublasLtOrder_t orderOut = get_order<TARGET>();
  int ldA = get_leading_dim<SRC>(dim1, dim2);
  int ldOut = get_leading_dim<TARGET>(dim1, dim2);

  cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL;
  cublasLtMatrixTransformDesc_t A2Out_desc = NULL;
  cublasOperation_t opTranspose = CUBLAS_OP_T;
  float transformAlpha = 1.0f, transformBeta = 0.0f;

  if(DTYPE == 8)
  {
    checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA));
    checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut));
  }
  else if(DTYPE == 32)
  {
    checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA));
    checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut));
  }
  else
  {
    printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE);
  }

  checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA)));
  checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut)));

  checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F));

  if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); }

  checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0));

  if (A_desc)     checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
  if (out_desc)   checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
  if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
}

template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int32_t, ROW, COL32, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);

template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
    int has_error = 0;
    cublasLtMatmulDesc_t matmulDesc = NULL;
    cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
    cublasOperation_t opT = CUBLAS_OP_T;
    cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
    cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
    cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C;
    cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4;

    has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda));
    has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb));

    has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
    if(FORMATB == COL_TURING)
      has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing)));
    else
      has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));

    if(DTYPE_OUT == 32)
    {
      has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I));
      has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
      has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));
      has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
      int alpha = 1, beta = 0;
      has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, &alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0));
    }
    else
    {
      has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F));
      has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
      has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc));
      has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
      if(!SCALE_ROWS)
      {
        float alpha = 1.0f, beta = 0.0f;
        has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, &alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
      }
      else
      {
        has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec)));
        has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
      }
    }

    if (Cdesc)      has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc));
    if (Bdesc)      has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc));
    if (Adesc)      has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc));
    if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
    if(has_error == 1)
      printf("error detected");

    return has_error;
}

int fill_up_to_nearest_multiple(int value, int multiple)
{
  return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
}

void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float *newRowStats, float *newcolStats, int numRows, int numCols)
{
  int threads = 512;
  int tileCols = fill_up_to_nearest_multiple(numCols, 32);
  int n = numRows * tileCols;
  int subtile_rows = 128;
  int tilesize = 32 * subtile_rows;
  int num_blocks = numRows / subtile_rows;
  num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
  num_blocks = num_blocks * (tileCols / 32);
  assert(threads <= tilesize);

  //cout << num_blocks << " blocks" << endl;

  kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
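The grid size for the dequantization launch above is pure integer arithmetic: pad the columns to a multiple of 32, cover the rows in 128-row sub-tiles, and launch one block per (sub-tile, 32-column tile) pair. A small self-contained sketch of that arithmetic, with made-up example dimensions and the helper copied locally only so it compiles on its own:

// Standalone sketch of the launch-geometry arithmetic used by dequant_mm_int32_fp16.
#include <cstdio>

static int fill_up_to_nearest_multiple(int value, int multiple)
{
    return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
}

int main()
{
    int numRows = 4096, numCols = 1000;                           // hypothetical output shape
    int subtile_rows = 128;
    int tileCols = fill_up_to_nearest_multiple(numCols, 32);      // pad columns to 32
    int num_blocks = (numRows + subtile_rows - 1) / subtile_rows; // 128-row sub-tiles
    num_blocks *= tileCols / 32;                                  // one block per 32-column tile
    printf("tileCols=%d blocks=%d\n", tileCols, num_blocks);
    return 0;
}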
#define STATS_THREADS 64
#define STATS_ITEMS 4
#define STATS_ROWS 16
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
{
  int tile_cols = STATS_THREADS*STATS_ITEMS;
  int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
  int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
  int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);

  if(nnz_threshold == 0.0)
    kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
  else if(nnz_threshold != 0.0)
    kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 1><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);

  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols)
{
  int threads = 64;
  int items_per_thread = 4;
  int tile_cols = threads*items_per_thread;
  int tile_rows = 16;
  int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
  int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
  int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);

  //cout << cols << " " << tiledCols << " " << tiledRows << endl;
  //cout << "num blocks " << num_blocks << endl;
  //cout << A << " " << out_col_normed << endl;

  if(threshold > 0.0f)
    kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
  else
    kDoubleRowColQuant<64, 4, 16, 64*4, 0><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);

  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols)
{
  int threads = 256;
  int items_per_thread = 8;
  // we load 128 column values per warp
  int tile_cols = 32*items_per_thread;
  int tile_rows = 32;
  int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
  int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
  int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
  int outCols = fill_up_to_nearest_multiple(cols, 32);
  int outRows = fill_up_to_nearest_multiple(rows, 32);
  if(FORMAT == COL_TURING)
  {
    if(TRANSPOSE)
      outRows = fill_up_to_nearest_multiple(cols, 8);
    else
      outRows = fill_up_to_nearest_multiple(rows, 8);
  }
  else if(FORMAT == COL_AMPERE)
  {
    if(TRANSPOSE)
      outRows = fill_up_to_nearest_multiple(cols, 32);
    else
      outRows = fill_up_to_nearest_multiple(rows, 32);
  }
  else
  {
    if(TRANSPOSE)
    {
      outCols = fill_up_to_nearest_multiple(rows, 32);
      outRows = cols;
    }
  }

  //cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl;
  //cout << "num blocks " << num_blocks << endl;
  //cout << A << " " << out_col_normed << endl;

  kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
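The branch ladder above only decides how much padding the transformed output buffer gets per format. A small self-contained sketch of the resulting output shapes for one example matrix in the non-transposed branches (dimensions are hypothetical; the helper is copied locally just so this compiles on its own):

#include <cstdio>

static int fill_up_to_nearest_multiple(int value, int multiple)
{
    return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
}

int main()
{
    int rows = 197, cols = 768;  // hypothetical row-major int8 input
    // non-transposed case: columns always padded to 32, rows padded per format
    printf("COL32      out: %d x %d\n", fill_up_to_nearest_multiple(rows, 32), fill_up_to_nearest_multiple(cols, 32));
    printf("COL_TURING out: %d x %d\n", fill_up_to_nearest_multiple(rows, 8),  fill_up_to_nearest_multiple(cols, 32));
    printf("COL_AMPERE out: %d x %d\n", fill_up_to_nearest_multiple(rows, 32), fill_up_to_nearest_multiple(cols, 32));
    return 0;
}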
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{
    cusparseSpMatDescr_t descA;
    cusparseDnMatDescr_t descB, descC;

    float alpha = 1.0f;
    float beta = 0.0f;
    void *dBuffer = NULL;
    size_t bufferSize = 0;

    CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz,
                                      A_rowidx, A_colidx, A_vals,
                                      CUSPARSE_INDEX_32I,
                                      CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) );
    // Create dense matrix C
    CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C,
                                        CUDA_R_16F, CUSPARSE_ORDER_ROW) );
    // Create dense matrix B
    if(transposed_B)
    {
      int tmp = A_cols;
      A_cols = B_cols;
      B_cols = tmp;
    }

    CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B,
                                        CUDA_R_16F, CUSPARSE_ORDER_ROW) );
    // allocate an external buffer if needed
    CHECK_CUSPARSE( cusparseSpMM_bufferSize(
                                 handle,
                                 CUSPARSE_OPERATION_NON_TRANSPOSE,
                                 transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
                                 &alpha, descA, descB, &beta, descC, CUDA_R_32F,
                                 CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) );
    CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) );

    // execute SpMM
    CHECK_CUSPARSE( cusparseSpMM(handle,
                                 CUSPARSE_OPERATION_NON_TRANSPOSE,
                                 transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
                                 &alpha, descA, descB, &beta, descC, CUDA_R_32F,
                                 CUSPARSE_SPMM_ALG_DEFAULT, dBuffer));

    // destroy matrix/vector descriptors
    CHECK_CUSPARSE( cusparseDestroySpMat(descA) );
    CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
    CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
    CUDA_CHECK_RETURN( cudaFree(dBuffer) );
}

template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{
  kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, 256>>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
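Both sparse paths above consume a COO triplet layout: one row index, one column index, and one value per nonzero. As a point of reference, here is a minimal CPU-only sketch of a COO-times-dense multiply with made-up numbers; it only illustrates the data layout the kernels receive, not the library's GPU implementation:

#include <cstdio>
#include <vector>

int main()
{
    // A is 2x3 with three nonzeros, stored as COO triplets (hypothetical values).
    std::vector<int>   rowidx = {0, 0, 1};
    std::vector<int>   colidx = {0, 2, 1};
    std::vector<float> vals   = {1.5f, -2.0f, 0.5f};

    // Dense B is 3x2, row-major; C = A * B is 2x2.
    float B[3][2] = {{1, 2}, {3, 4}, {5, 6}};
    float C[2][2] = {{0, 0}, {0, 0}};

    for (size_t n = 0; n < vals.size(); n++)     // scatter each nonzero into its output row
        for (int j = 0; j < 2; j++)
            C[rowidx[n]][j] += vals[n] * B[colidx[n]][j];

    printf("C = [[%g, %g], [%g, %g]]\n", C[0][0], C[0][1], C[1][0], C[1][1]);
    return 0;
}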
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================

template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);

template int igemmlt<COL_TURING, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_TURING, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_TURING, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_AMPERE, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_AMPERE, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_AMPERE, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);

template void transformRowToFormat<COL32, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL32, 1>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_TURING, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_TURING, 1>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_AMPERE, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows, int cols);

template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);
...
csrc/ops.cuh
View file @ c771b3a7
...
@@ -14,6 +14,11 @@
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>
#include <cublasLt.h>
#include <cusparse.h>
#include <vector>
#include <functional>

#define CUDA_CHECK_RETURN(value) { \
  cudaError_t _m_cudaStat = value; \
...
@@ -25,6 +30,34 @@
#define THREADS_PER_BLOCKS (512)

#define CHECK_CUSPARSE(value) { \
  cusparseStatus_t _m_cudaStat = value; \
  if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \
    fprintf(stderr, "Error %s at line %d in file %s\n", \
            cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
    exit(1); \
  } }

#define THREADS_PER_BLOCKS (512)

inline void checkCudaStatus(cudaError_t status) {
  if (status != cudaSuccess) {
    printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status));
    throw std::logic_error("cuda API failed");
  }
}

inline int checkCublasStatus(cublasStatus_t status) {
  if (status != CUBLAS_STATUS_SUCCESS) {
    printf("cuBLAS API failed with status %d\n", status);
    //throw std::logic_error("cuBLAS API failed");
    return 1;
  }
  return 0;
}

typedef enum Operations_t
{
  ksmul = 0,
...
@@ -39,6 +72,57 @@ typedef enum Optimizer_t
  ADAGRAD = 4,
} Optimizer_t;

typedef enum Transform_t
{
  ROW = 0,
  COL = 1,
  COL32 = 2,
  COL_TURING = 3,
  COL_AMPERE = 4,
} Transform_t;

class Context
{
  public:
    cublasHandle_t m_handle;

    Context()
    {
      cublasHandle_t handle;
      cublasCreate_v2(&handle);
      m_handle = handle;
    }
};

class ContextLt
{
  public:
    cublasLtHandle_t m_handle;

    ContextLt()
    {
      cublasLtHandle_t handle;
      cublasLtCreate(&handle);
      m_handle = handle;
    }
};

class ContextCusparse
{
  public:
    cusparseHandle_t m_handle;

    ContextCusparse()
    {
      cusparseHandle_t handle;
      cusparseCreate(&handle);
      m_handle = handle;
    }
};
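These thin wrapper classes only create a library handle in the constructor and never release it; the caller keeps one per process and passes the raw m_handle to the functions declared below. A minimal host-side sketch of the same pattern (assumes a CUDA toolkit is available; explicit error handling and handle destruction are added here for illustration and are not part of the header above):

#include <cstdio>
#include <cublasLt.h>

int main()
{
    // Mirror of ContextLt: create the handle once and reuse it for every call.
    cublasLtHandle_t handle;
    if (cublasLtCreate(&handle) != CUBLAS_STATUS_SUCCESS) {
        printf("failed to create cublasLt handle\n");
        return 1;
    }

    // ... pass `handle` to the igemmlt<...>() / transform<...>() entry points declared above ...

    cublasLtDestroy(handle);   // the header's ContextLt never does this
    return 0;
}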
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n);
...
@@ -70,4 +154,24 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,

void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);

void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
                    long long int strideA, long long int strideB, long long int strideC, int batchCount);

template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float *newRowStats, float *newcolStats, int numRows, int numCols);
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed,
                       int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);

template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);

void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);

template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);

#endif
csrc/pythonInterface.c
View file @ c771b3a7
...
@@ -84,6 +84,52 @@ void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half

void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
#endif

#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
{ \
	transform<dtype, src, target, transpose, bits>(ltHandle, A, out, dim1, dim2); \
} \

MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8);
MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8);
MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8);
MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32);
MAKE_FUNC_TRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8);
MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8);
MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8);
MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32);
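The token-pasting macro above just stamps out one plain-named wrapper per template instantiation so the symbols can be resolved from Python. A tiny standalone illustration of the same pattern, with generic names that are not part of the library:

#include <cstdio>

// A template we want to expose under plain, non-templated names.
template <int SCALE> int scale_by(int x) { return x * SCALE; }

// Stamp out one thin wrapper per instantiation, the same trick as MAKE_FUNC_TRANSFORM.
#define MAKE_FUNC_SCALE(name, factor) \
  int scale_##name(int x) { return scale_by<factor>(x); }

MAKE_FUNC_SCALE(double_it, 2)
MAKE_FUNC_SCALE(quadruple_it, 4)

int main()
{
    printf("%d %d\n", scale_double_it(3), scale_quadruple_it(3));  // prints: 6 12
    return 0;
}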
void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 0>(A, out, rows, cols); }
void transform_row2col32T(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 1>(A, out, rows, cols); }
void transform_row2turing(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 0>(A, out, rows, cols); }
void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 1>(A, out, rows, cols); }
void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols); }
void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols); }

int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

int igemmlt_turing_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt<COL_TURING, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

int igemmlt_turing_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt<COL_TURING, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

int igemmlt_ampere_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt<COL_AMPERE, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

int igemmlt_ampere_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt<COL_AMPERE, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

int igemmlt_ampere_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt<COL_AMPERE, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive<half, 16>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }

void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive<signed char, 8>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }

extern "C"
{
	#if BUILD_CUDA
...
@@ -155,7 +201,86 @@ extern "C"
	void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
	void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
	#endif

	void cigemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
	{ gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); }
	void cbatched_igemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
	                    long strideA, long strideB, long strideC, int batchCount)
	{ strided_gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount); }

	Context *get_context(){ return new Context(); }
	ContextCusparse *get_cusparse(){ return new ContextCusparse(); }

	int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
	//{ (cublasLtHandle_t)context->m_handle; return 0; }
	//{ return 0; }//igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

	int cigemmlt_turing_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt_turing_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

	int cigemmlt_turing_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt_turing_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

	int cigemmlt_ampere_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt_ampere_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

	int cigemmlt_ampere_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }

	int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
	{ return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
	#define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
	void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \
	{ \
		transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \
	} \

	MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8)
	MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8)
	MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8)
	MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32)
	MAKE_FUNC_CTRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8)
	MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8)
	MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8)
	MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32)

	void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float *newRowStats, float *newcolStats, int numRows, int numCols)
	{ dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols); }
	void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
	{ getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); }
	void cdouble_rowcol_quant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols)
	{ doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); }

	void ctransform_row2col32(char * A, char *out, int rows, int cols)
	{ transform_row2col32(A, out, rows, cols); }
	void ctransform_row2col32T(char * A, char *out, int rows, int cols)
	{ transform_row2col32T(A, out, rows, cols); }
	void ctransform_row2turing(char * A, char *out, int rows, int cols)
	{ transform_row2turing(A, out, rows, cols); }
	void ctransform_row2turingT(char * A, char *out, int rows, int cols)
	{ transform_row2turingT(A, out, rows, cols); }
	void ctransform_row2ampere(char * A, char *out, int rows, int cols)
	{ transform_row2ampere(A, out, rows, cols); }
	void ctransform_row2ampereT(char * A, char *out, int rows, int cols)
	{ transform_row2ampereT(A, out, rows, cols); }

	void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
	{ spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); }

	void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
	{ spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }

	void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
	{ spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }

	#endif

	void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
	void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
}
...
tests/test_autograd.py
0 → 100644
View file @ c771b3a7

import pytest
import torch
import bitsandbytes as bnb

from itertools import product

n = 1
k = 25
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ['bmm', 'matmul']
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ['FF', 'TF', 'TT', 'FT']
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ['FF', 'FT', 'TT', 'TF']
dtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
            dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
            A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0])
            B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
            target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1])
            torch.nn.init.xavier_uniform_(B)

            if not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B)
            elif not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B.t())
            elif transpose[0] and not transpose[1]:
                out_torch = funcs[0](A.t(), B)
                out_bnb = funcs[1](A.t(), B)
            elif transpose[0] and transpose[1]:
                out_torch = funcs[0](A.t(), B.t())
                out_bnb = funcs[1](A.t(), B.t())

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n*0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n*0.001

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

                if req_grad[0]:
                    torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
                if req_grad[1]:
                    n = gradB1.numel()
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (idx == 0).sum().item() < n*0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() < n*0.02
                    torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)

        # batched matrix multiply
        if funcs[0] in [torch.bmm, torch.matmul]:
            A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
            B = torch.randn(size=(dim1, dim3, dim4), device='cuda', requires_grad=req_grad[1])
            target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
            torch.nn.init.xavier_uniform_(B)

            out_torch = funcs[0](A, B)
            out_bnb = funcs[1](A, B)

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n*0.01
            torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

                if req_grad[0]:
                    torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
                if req_grad[1]:
                    n = gradB1.numel()
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (idx == 0).sum().item() < n*0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() < n*0.02

        if funcs[0] in [torch.matmul]:
            dim1 = dim1 - (dim1 % 16)
            A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
            dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
            B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
            target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
            torch.nn.init.xavier_uniform_(B)

            if transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B.t())
            else:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B)

            n = out_bnb.numel()
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n*0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n*0.001

            if any(req_grad):
                out_bnb.data.copy_(out_torch)
                torch.cuda.synchronize()
                loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                loss_bnb.backward()
                gradA1 = A.grad
                gradB1 = B.grad
                A.grad = None
                B.grad = None

                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                loss_torch.backward()
                gradA2 = A.grad
                gradB2 = B.grad
                A.grad = None
                B.grad = None

                if req_grad[0]:
                    torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
                if req_grad[1]:
                    n = gradB1.numel()
                    idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                    assert (idx == 0).sum().item() < n*0.1
                    idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                    assert (idx == 0).sum().item() < n*0.02
n = 1
k = 3
dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()

#dim1 = (17,)
#dim2 = (7,)
#dim3 = (37,)
#dim4 = (23,)

decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
str_funcs = ['matmul']
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ['FF', 'TF', 'TT', 'FT']
transpose = [(False, True), (False, False)]
str_transpose = ['NT', 'NN']
dtype = [torch.float16]
has_fp16_weights = [True, False]
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights))
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, decomp, has_fp16_weights))
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}'.format(*vals) for vals in str_values]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", values, ids=names)
def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights):
    dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
    dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1]//8,), device='cuda')

    for i in range(k):
        # normal multiply
        if funcs[0] in [torch.mm, torch.matmul]:
            A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0], dtype=dtype)
            if decomp == 6.0:
                with torch.no_grad():
                    A[:, outlier_dim] = 6.0
            B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1], dtype=dtype)
            target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1], dtype=dtype)
            torch.nn.init.xavier_uniform_(B)
            B2 = B.clone()

            state = bnb.MatmulLtState()
            state.threshold = decomp
            state.has_fp16_weights = has_fp16_weights
            if not has_fp16_weights:
                if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous()
                state.CB, CBt, state.SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B2)
                B2 = state.CB

            if not transpose[0] and transpose[1]:
                out_torch = funcs[0](A, B.t())
                out_bnb = funcs[1](A, B2, state=state)
            elif not transpose[0] and not transpose[1]:
                out_torch = funcs[0](A, B)
                out_bnb = funcs[1](A, B2.t(), state=state)

            n = out_bnb.numel()
            err = torch.abs(out_bnb - out_torch).mean().item()
            #print(f'abs error {err:.4f}')
            idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
            assert (idx == 0).sum().item() < n*0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
            assert (idx == 0).sum().item() < n*0.001

            if has_fp16_weights:
                if any(req_grad):
                    out_bnb.data.copy_(out_torch)
                    torch.cuda.synchronize()
                    loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
                    loss_bnb.backward()
                    gradA1 = A.grad
                    gradB1 = B.grad
                    A.grad = None
                    B.grad = None

                    loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
                    loss_torch.backward()
                    gradA2 = A.grad
                    gradB2 = B.grad
                    A.grad = None
                    B.grad = None

                    if req_grad[0]:
                        torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
                    if req_grad[1]:
                        n = gradB1.numel()
                        assert torch.abs(gradB1).sum() > 0.0
                        assert torch.abs(gradB2).sum() > 0.0
                        idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                        assert (idx == 0).sum().item() < n*0.1
                        idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                        assert (idx == 0).sum().item() < n*0.02
                        torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
tests/test_functional.py
View file @ c771b3a7

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import math
import random
import time
import torch
import bitsandbytes as bnb
import einops

from itertools import product

from bitsandbytes import functional as F

torch.set_printoptions(precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
k = 20

def assert_all_approx_close(a, b, rtol, atol, count):
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f'Too many values not close: assert {sumval} < {count}')
        torch.testing.assert_allclose(a, b, rtol, atol)

class FFN(torch.nn.Module):
    def __init__(self, input_features, hidden_size, bias=True):
        super(FFN, self).__init__()
        self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
        self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)

        with torch.no_grad():
            torch.nn.init.xavier_uniform_(self.fc1.weight)
            torch.nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class Timer(object):
    def __init__(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}

    def tick(self, name='default'):
        if name not in self.starts:
            self.starts[name] = torch.cuda.Event(enable_timing=True)
            self.ends[name] = torch.cuda.Event(enable_timing=True)
            self.starts[name].record()
        else:
            ms = self.tock(name, evict=True, print_ms=False)

    def tock(self, name='default', evict=True, print_ms=True):
        if name in self.ends:
            self.ends[name].record()
            torch.cuda.synchronize()
            ms = self.starts[name].elapsed_time(self.ends[name])
            if name not in self.agg: self.agg[name] = 0.0
            self.agg[name] += ms
            if evict:
                self.starts.pop(name)
                self.ends.pop(name)

        if print_ms and name in self.agg:
            print('{0} took: {1:.5f}s'.format(name, self.agg[name]/1000.0))

        return self.agg[name]

    def reset(self):
        self.starts = {}
        self.ends = {}
        self.agg = {}
        print('Resetting benchmark data')

def setup():
    pass
...
@@ -64,8 +125,8 @@ def test_dynamic_quantization():
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diff.mean().item() < 0.0135
        print(sum(diffs)/len(diffs))
        #print(sum(diffs)/len(diffs))
        print(sum(reldiffs)/len(reldiffs))
        #print(sum(reldiffs)/len(reldiffs))

    for i in range(100):
        A1 = torch.rand(1024, 1024, device='cuda')
...
@@ -88,8 +149,8 @@ def test_dynamic_blockwise_quantization():
        diffs.append(diff.mean().item())
        reldiffs.append(reldiff.mean().item())
        assert diffs[-1] < 0.011
        print(sum(diffs)/len(diffs))
        #print(sum(diffs)/len(diffs))
        print(sum(reldiffs)/len(reldiffs))
        #print(sum(reldiffs)/len(reldiffs))

    diffs = []
    for i in range(100):
...
@@ -125,7 +186,7 @@ def test_percentile_clipping(gtype):
    n = 4
    step = 0
    percentile = 5
    for i in range(1000):
    for i in range(k):
        step += 1
        g = torch.randn(n, n, dtype=gtype, device='cuda')
        gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
...
@@ -145,69 +206,1653 @@
    torch.testing.assert_allclose(gnorm1, gnorm2)
def quant(x):
    max1 = torch.abs(x).max()
    x = torch.round(x/max1*127)
    return max1, x.to(torch.int8)

def dequant(c, maxC):
    return c.float()*(maxC/127)

def mm_dequant(maxA, maxB, C):
    return C.float()*(maxA/127)*(maxB/127)

def quant_multi(x, dim):
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0
    x = torch.round(x/max1*127)
    return max1, x.to(torch.int8)

def quant_multi_chunk(x, dim, chunk_size=32):
    if dim == 1:
        x_chunked = einops.rearrange(x, '(c a) b -> c a b', c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim+1, keepdim=True)
        max1 = torch.tile(max1, (1, 1, x.shape[1]))
        max1 = max1.view(x.shape)
    elif dim == 0:
        x_chunked = einops.rearrange(x, 'a (b c) -> a b c', c=chunk_size)
        max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
        max1 = torch.tile(max1, (x.shape[0], 1, 1))
        max1 = max1.view(x.shape)
    max1[max1 == 0] = 1.0
    x = torch.round(x/max1*127)
    return max1, x.to(torch.int8)

def quant_minmax(A):
    minA = A.min()
    maxA = A.max()

def mean(xx):
    return sum(xx)/float(len(xx))

#dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
#dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
dim1 = [1024*2]
dim2 = [1024*16]
methods = [(lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant)]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
#methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
method_names = ['linear', 'vectorwise']
batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = ['dim1_{0}_dim2_{1}_quant_{2}_batched_{3}'.format(*vals) for vals in values_names]
@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim1 = dim1 - (dim1 % 32)
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    print('')
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2//32), device='cuda')
            B = torch.normal(0, 0.5, size=(32, dim2//32, dim1), device='cuda')
            maxA, Ac = quant_methods[0](A, 2)
            maxB, Bc = quant_methods[1](B, 1)
        else:
            A = torch.normal(0, 0.5, size=(dim1, dim2), device='cuda')
            B = torch.normal(0, 0.5, size=(dim2, dim1), device='cuda')
            maxA, Ac = quant_methods[0](A, 1)
            maxB, Bc = quant_methods[1](B, 0)
        torch.testing.assert_allclose(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
        if batched:
            out2 = torch.bmm(A, B)
            C = torch.bmm(Ac.float(), Bc.float())
        else:
            out2 = torch.mm(A, B)
            C = F.igemm(Ac, Bc)
        out = quant_methods[4](maxA, maxB, C)
        std = out2.std()
        out /= std
        out2 /= std
        err = torch.abs(out - out2)
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    print(mean(errors))
    print(mean(relerrors))
def test_stable_embedding():
    layer = bnb.nn.StableEmbedding(1024, 1024)
    layer.reset_parameters()

def test_dynamic_blockwise_quantization_cpu():
    #A1 = torch.randn(1024, 1024, device='cpu')
    #code = F.create_dynamic_map()
    #for i in range(1000):
    #    C, S = F.quantize_blockwise(A1, code=code)
    #    A2 = F.dequantize_blockwise(C, S)
    for i in range(10):
        # equivalence with GPU blockwise quantization
        A1 = torch.randn(1024, 1024, device='cpu')
        C1, S1 = F.quantize_blockwise(A1)
        C2, S2 = F.quantize_blockwise(A1.cuda())
        torch.testing.assert_allclose(S1[0], S2[0].cpu())
        # there seems to be some issues with precision in CUDA vs CPU
        # not all elements are usually close, with couple off elements in a million
        idx = torch.isclose(C1, C2.cpu())
        assert (idx == 0).sum().item() < 15

n = 2
hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
batch_dim = torch.randint(16, 256, size=(n,)).tolist()
seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = ['hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 16)
    seq_dim = seq_dim - (seq_dim % 16)
    for i in range(k):
        shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
        shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4)))
        A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())
        elif transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.t().float(), B.float())
            out = F.igemm(A.t(), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.matmul(A.t().float(), B.t().float())
            out = F.igemm(A.t(), B.t())

        torch.testing.assert_allclose(out.float(), out2)
    diffs = []
    for i in range(k):
        reldiffs = []
        shapeA = (batch_dim, seq_dim, hidden_dim)
        shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4)))
        A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
        if not transpose[0] and not transpose[1]:
            out2 = torch.matmul(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.matmul(A.float(), B.t().float())
            out = F.igemm(A, B.t())

        torch.testing.assert_allclose(out.float(), out2)

n = 3
seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024*4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = ['seq_dim{0}_hidden_dim{1}_batch_dim{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
    seq_dim = seq_dim - (seq_dim % 32)
    hidden_dim = hidden_dim - (hidden_dim % 32)
    batch_dim = batch_dim - (batch_dim % 2)
    for i in range(25):
        A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device='cuda').to(torch.int8)
        B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device='cuda').to(torch.int8)
        out2 = torch.einsum('bsi, bso->io', A.float(), B.float())
        iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
        out = F.igemm(A, B, out=iout)

        torch.testing.assert_allclose(out.float(), out2)
n
=
2
seq_dim
=
torch
.
randint
(
32
,
512
,
size
=
(
n
,)).
tolist
()
hidden_dim
=
torch
.
randint
(
32
,
1024
*
4
,
size
=
(
n
,)).
tolist
()
batch_dim
=
torch
.
randint
(
2
,
16
,
size
=
(
n
,)).
tolist
()
transpose
=
[
False
,
True
]
values
=
list
(
product
(
seq_dim
,
hidden_dim
,
batch_dim
,
transpose
))
names
=
[
'seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"seq_dim, hidden_dim, batch_dim, transpose"
,
values
,
ids
=
names
)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
    def min_max(x):
        maxA = torch.amax(x, dim=2, keepdim=True)
        minA = torch.amin(x, dim=2, keepdim=True)
        scale = (maxA - minA) / 2.0
        return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale

    seq_dim = seq_dim - (seq_dim % 16)
    hidden_dim = hidden_dim - (hidden_dim % 16)
    batch_dim = batch_dim - (batch_dim % 2)
    errs = []
    relerrs = []
    errs2 = []
    relerrs2 = []
    for i in range(k):
        A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device='cuda')
        if transpose:
            B = torch.normal(0, 0.5, size=(256, hidden_dim), device='cuda')
        else:
            B = torch.normal(0, 0.5, size=(hidden_dim, 256), device='cuda')
        Ac, minA, scale = min_max(A)
        if transpose:
            maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
            out = F.igemm(Ac, Bc.t())
            out2 = torch.matmul(A, B.t())
            offset = B.t().sum(0) * (minA + scale)
            out = out.float()
            out = (out * maxB.t() * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc.t())
            out3 = mm_dequant(maxA, maxB.t(), out3)
        else:
            maxB, Bc = quant_multi(B, dim=0)
            offset = B.sum(0) * (minA + scale)
            out = F.igemm(Ac, Bc)
            out2 = torch.matmul(A, B)
            out = out.float()
            out = (out * maxB * scale / (127 * 127)) + offset

            maxA, Ac = quant_multi(A, dim=2)
            out3 = F.igemm(Ac, Bc)
            out3 = mm_dequant(maxA, maxB, out3)

        std = out2.std()
        out2 /= std
        out /= std
        out3 /= std

        err = torch.abs(out - out2)
        relerr = err / (torch.abs(out2) + 1e-7)

        err2 = torch.abs(out3 - out2)
        relerr2 = err2 / (torch.abs(out2) + 1e-7)

        errs.append(err.mean().item())
        relerrs.append(relerr.mean().item())
        errs2.append(err2.mean().item())
        relerrs2.append(relerr2.mean().item())

    #print(mean(errs))
    #print(mean(relerrs))
    #print(mean(errs2))
    #print(mean(relerrs2))
    assert mean(errs) < 0.015
    assert mean(relerrs) < 0.3
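
# Hedged sketch of the de-quantization identity exercised by the non-transposed branch
# above (helper name is illustrative only, not part of the library API). With
#   Ac = 127 * (A - minA - scale) / scale,  scale = (maxA - minA) / 2,  Bc = 127 * B / maxB,
# we have A ~ Ac * scale / 127 + (minA + scale) and B ~ Bc * maxB / 127, hence
#   A @ B ~ (Ac @ Bc) * scale * maxB / (127 * 127) + (minA + scale) * B.sum(0),
# which is the 'out * maxB * scale / (127*127) + offset' correction in the test.
def _minmax_dequant_reference(int_out, maxB, scale, offset):
    # int_out: int32 result of F.igemm(Ac, Bc); returns an approximation of A @ B
    return int_out.float() * maxB * scale / (127 * 127) + offset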

n = 2
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    dim4 = dim4 - (dim4 % 16)
    for i in range(k):
        shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
        shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
        A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
        B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)

        if not transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.float(), B.float())
            out = F.igemm(A, B)
        elif not transpose[0] and transpose[1]:
            out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A, B.permute([0, 2, 1]))
        elif transpose[0] and not transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
            out = F.igemm(A.permute([0, 2, 1]), B)
        elif transpose[0] and transpose[1]:
            out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
            out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
        torch.testing.assert_allclose(out.float(), out2.float())

n = 1
dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
names = ['dim1_{0}_dim2_{1}_dim3_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
    dim2 = dim2 - (dim2 % 16)
    dim3 = dim3 - (dim3 % 16)
    for i in range(k):
        A = torch.randn(size=(dim2, dim3), device='cuda')
        qA, SA = F.vectorwise_quant(A, dim=0)
        A1 = F.vectorwise_dequant(qA, SA)
        torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1)

n = 2
dim1 = torch.randint(2, 256, size=(n,)).tolist()
dim2 = torch.randint(2, 256, size=(n,)).tolist()
dim3 = torch.randint(2, 256, size=(n,)).tolist()
#dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
a_order = ['row']
out_order = ['col', 'row', 'col32']
transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names)
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    if dims == 3 and out_order != 'col32': return
    if dtype == torch.int32 and out_order != 'col32': return
    func = F.get_transform_func(dtype, orderA, orderOut, transpose)

    if dims == 2:
        A = torch.randint(-128, 127, size=(dim1, dim2), device='cuda').to(dtype)
    elif dims == 3:
        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(dtype)

    out, S = F.nvidia_transform(A, to_order=orderOut)

    if orderOut == 'row':
        torch.testing.assert_allclose(A.flatten(), out.flatten())
    elif orderOut == 'col':
        torch.testing.assert_allclose(A.t().flatten(), out.flatten())
    elif orderOut == 'col32':
        if dims == 2:
            n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
        elif dims == 3:
            n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
        assert out.numel() == n
    elif orderOut == 'col_turing':
        # 32 col 8 row tiles
        n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32)))
        assert out.numel() == n
        total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
        for row in range(A.shape[0]):
            for col in range(A.shape[1]):
                i = row * A.shape[1]
                j = col

                coltile = (col // 32) + (1 if col % 32 != 0 else 0)
                rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile
                offset = 32 * 8 * (rowtile + coltile)
                col2 = col % 32
                row2 = (row % 8) * 32

                assert A.flatten()[i + j] == A[row, col]
                #assert A.flatten()[i+j] == out.flatten()[row2+col2]
                #torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
                #torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

    if orderOut == 'col32':
        out2, S = F.nvidia_transform(out, from_order=orderOut, to_order='row', state=S)
        torch.testing.assert_allclose(A, out2)
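
# Small illustrative helper (not a library API; it only restates the size check above):
# for a 2D input the 'col32' layout produced by nvidia_transform is padded along the
# last dimension, and the test compares out.numel() against rows * (cols + 32 - cols % 32).
def _col32_numel_2d(rows, cols):
    return rows * (cols + (32 - cols % 32))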

n = 1
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()
#dim1 = [2]
#dim2 = [2]
#dim3 = [2]
#dim4 = [2]
dims = (2, 3)
ldb = [0]
#ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
    for i in range(k):
        if dims == 2:
            A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
        elif dims == 3:
            A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, 'col32')
        B2, SB = F.transform(B, 'col_turing')
        C2, SC = F.igemmlt(A2, B2, SA, SB)
        C3, S = F.nvidia_transform(C2, 'row', state=SC)
        torch.testing.assert_allclose(C1, C3.float())

        # transpose
        B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
        C1 = torch.matmul(A.float(), B.float())

        B2t, SBt = F.transform(B, 'col_turing', transpose=True)
        C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        C3, S = F.nvidia_transform(C2, 'row', state=SC)
        torch.testing.assert_allclose(C1, C3.float())
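
# End-to-end usage sketch of the int8 path exercised above (a minimal example, assuming
# plain row-major int8 inputs; the helper name and the 'col_turing' choice mirror this
# test, while other tests in this file use F.get_special_format_str() to pick
# 'col_ampere' instead).
def _igemmlt_rowmajor_reference(A_int8, B_int8):
    # A_int8: (m, k) int8, B_int8: (n, k) int8 -> int32 approximation of A @ B^T, row-major
    A2, SA = F.transform(A_int8, 'col32')            # activations -> col32 layout
    B2, SB = F.transform(B_int8, 'col_turing')       # weights -> GPU tile layout
    C2, SC = F.igemmlt(A2, B2, SA, SB)               # int8 GEMM, int32 accumulation
    C3, _ = F.nvidia_transform(C2, 'row', state=SC)  # back to row-major
    return C3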

dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]
dims = (2,)
#ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
    formatB = F.get_special_format_str()
    for i in range(k):
        if dims == 2:
            A = torch.normal(0, 0.5, size=(dim1, dim3), device='cuda').half()
        elif dims == 3:
            A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device='cuda').half()
        B = torch.randn((dim4, dim3), device='cuda').half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())
        C2 = bnb.matmul(A, B.t())

        A = A.view(-1, A.shape[-1])

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
        C32A, SA = F.transform(CA, 'col32')
        CxB, SB = F.transform(CB, to_order=formatB)
        out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
        output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)

        #print('')
        #print(output.flatten()[:10])
        #print(C1.flatten()[:10])
        #print(C2.flatten()[:10])

        #torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)

        # transpose
        #B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
        #C1 = torch.matmul(A.float(), B.float())

        #B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
        #C2, SC = F.igemmlt(A2, B2t, SA, SBt)
        #C3, S = F.transform(C2, 'row', state=SC)
        #torch.testing.assert_allclose(C1, C3.float())
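
# Note on the F.double_quant outputs consumed above (as checked by test_double_quant
# further down in this file; treat the shape remarks as a reading of those tests rather
# than as full API documentation):
#   CA         - int8 row-wise quantized A (same shape as A)
#   CAt        - int8 column-wise quantized A (same shape as A)
#   statsA     - per-row absmax used for CA (one value per row)
#   statsAt    - per-column absmax used for CAt (one value per column)
#   coo_tensor - sparse outlier entries; None here because no threshold is passed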
batch_size
=
2
seqdim
=
512
#values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
values
=
[(
batch_size
,
seqdim
,
4
*
1024
,
3
*
4
*
1024
),(
batch_size
,
seqdim
,
5120
,
3
*
5120
),(
batch_size
,
seqdim
,
12
*
1024
,
4
*
12
*
1024
)]
#values = list(product(batch, seq, model, hidden))
names
=
[
'batch_{0}_seq_{1}_model_{2}_hidden_{3}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"batch, seq, model, hidden"
,
values
,
ids
=
names
)
def
test_bench_8bit_training
(
batch
,
seq
,
model
,
hidden
):
formatB
=
F
.
get_special_format_str
()
A
=
torch
.
randn
(
batch
,
seq
,
model
,
device
=
'cuda'
).
half
()
grad
=
torch
.
randn
(
batch
,
seq
,
model
,
device
=
'cuda'
).
half
()
w1
=
torch
.
randint
(
-
128
,
127
,
size
=
(
hidden
,
model
),
device
=
'cuda'
).
half
()
w2
=
torch
.
randint
(
-
128
,
127
,
size
=
(
model
,
hidden
),
device
=
'cuda'
).
half
()
print
(
''
)
#torch.cuda.synchronize()
## warmup
#for i in range(100):
# torch.matmul(A, w1.t())
#torch.cuda.synchronize()
dtype
=
torch
.
int8
A
=
A
.
view
(
-
1
,
A
.
shape
[
-
1
]).
contiguous
()
grad
=
grad
.
view
(
-
1
,
grad
.
shape
[
-
1
]).
contiguous
()
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
k
):
out1
=
torch
.
matmul
(
A
,
w1
.
t
())
# fc1
#out2 = torch.matmul(out1, w2.t())# fc2
#d1 = torch.matmul(grad, w2) # delta1
#d2 = torch.matmul(d1, w1) # delta2
#grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
#grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1
torch
.
cuda
.
synchronize
()
t16
=
time
.
time
()
-
t0
print
(
t16
)
#torch.cuda.empty_cache()
#Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
#Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
#CTw1, Sw1 = F.transform2(Cw1, formatB)
#CTw2, Sw2 = F.transform2(Cw2, formatB)
#CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
#CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
#CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
#C32A, SA = F.transform2(CA, 'col32')
## fc1
#out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)
## fc2
#Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
#C32out1, Sout1 = F.transform2(Cout1, 'col32')
#out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)
## delta1
#Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
#C32grad, Sgrad = F.transform2(Cgrad, 'col32')
##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)
## delta2
#Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
#C32d1, Sd1 = F.transform2(Cd1, 'col32')
##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)
## grad1
#C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
#CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)
## grad2
#C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
#CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)
#Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
#Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
#Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
#CTw1, Sw1 = F.transform2(Cw1, formatB)
#CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
#CTw2, Sw2 = F.transform2(Cw2, formatB)
#CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(k):
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
# CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# #CTw2, Sw2 = F.transform2(Cw2, formatB)
# #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# C32A, SA = F.transform2(CA, 'col32')
# # fc1
# out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
# #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
# #print(coo_tensor.nnz)
# #out1sp = F.spmm_coo(coo_tensor, w1.t())
# #print(w1.t().shape)
# #out1 = out1dn + out1sp
# # fc2
# Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
# C32out1, Sout1 = F.transform2(Cout1, 'col32')
# out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
# #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)
# # delta1
# Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
# C32grad, Sgrad = F.transform2(Cgrad, 'col32')
# d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
# #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)
# # delta2
# Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
# C32d1, Sd1 = F.transform2(Cd1, 'col32')
# d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
# #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)
# # grad1
# #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
# #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
# #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
# #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)
# ## grad2
# #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
# #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
# #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
# #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)
#torch.cuda.synchronize()
#t8 = time.time() - t0
#print(t8)
n
=
2
dim1
=
torch
.
randint
(
64
,
256
,
size
=
(
n
,)).
tolist
()
dim4
=
torch
.
randint
(
64
,
1024
,
size
=
(
n
,)).
tolist
()
#dim1 = [2*1024]
#dim4 = [2*1024]
#dim1 = [4]
#dim4 = [4]
dims
=
(
2
,)
#ldb = list(range(256, 1*1024, 256))
formatB
=
[
'col_turing'
,
'col_ampere'
]
values
=
list
(
product
(
dim1
,
dim4
,
dims
,
formatB
))
names
=
[
'dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim4, dims, formatB"
,
values
,
ids
=
names
)
def
test_dequant_mm
(
dim1
,
dim4
,
dims
,
formatB
):
inner
=
torch
.
randint
(
1
,
128
,
size
=
(
1
,)).
item
()
formatB
=
F
.
get_special_format_str
()
for
i
in
range
(
k
):
A
=
torch
.
randn
(
dim1
,
inner
,
device
=
'cuda'
)
B
=
torch
.
randn
(
dim4
,
inner
,
device
=
'cuda'
)
C1
=
torch
.
matmul
(
A
.
half
(),
B
.
t
().
half
())
A1
,
maxA
=
F
.
vectorwise_quant
(
A
,
dim
=
1
)
B1
,
maxB
=
F
.
vectorwise_quant
(
B
,
dim
=
1
)
A2
,
SA
=
F
.
nvidia_transform
(
A1
,
'col32'
)
B2
,
SB
=
F
.
nvidia_transform
(
B1
,
formatB
)
C2
,
SC
=
F
.
igemmlt
(
A2
,
B2
,
SA
,
SB
)
C3
,
S
=
F
.
nvidia_transform
(
C2
,
'row'
,
state
=
SC
)
C4
=
F
.
vectorwise_mm_dequant
(
C3
.
float
(),
maxA
,
maxB
.
t
())
count
=
(
torch
.
isclose
(
C1
,
C4
,
atol
=
0.01
,
rtol
=
0.1
)
==
0
).
sum
().
item
()
n
=
C1
.
numel
()
p
=
0.06
assert
count
/
n
<
p
,
f
'error in more than
{
p
}
of elements:
{
count
}
/
{
n
}
=
{
count
/
n
}
'
C5
=
F
.
mm_dequant
(
C2
,
SC
,
maxA
.
flatten
(),
maxB
.
flatten
())
torch
.
testing
.
assert_allclose
(
C5
,
C4
)
#print(C2)
n
=
2
dim1
=
[
1
*
1024
]
dim2
=
[
1
*
1024
]
#dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
#dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dims
=
(
2
,)
#ldb = list(range(256, 1*1024, 256))
values
=
list
(
product
(
dim1
,
dim2
,
dims
))
names
=
[
'dim1_{0}_dim2_{1}_dims_{2}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, dims"
,
values
,
ids
=
names
)
def
test_colrow_absmax
(
dim1
,
dim2
,
dims
):
for
i
in
range
(
k
):
threshold
=
3.0
A
=
torch
.
randn
(
dim1
,
dim2
,
device
=
'cuda'
).
half
()
A_truncated
=
A
.
clone
()
A_truncated
[
torch
.
abs
(
A_truncated
)
>=
3.0
]
=
0.0
if
dims
==
2
:
row_stats1
,
_
=
torch
.
abs
(
A
.
float
()).
max
(
1
)
col_stats1
,
_
=
torch
.
abs
(
A
.
float
()).
max
(
0
)
row_stats1_trunc
,
_
=
torch
.
abs
(
A_truncated
.
float
()).
max
(
1
)
col_stats1_trunc
,
_
=
torch
.
abs
(
A_truncated
.
float
()).
max
(
0
)
else
:
assert
False
row_stats2
,
col_stats2
,
nnz_block_ptr2
=
F
.
get_colrow_absmax
(
A
,
threshold
=
threshold
)
A_blocked
=
einops
.
rearrange
(
torch
.
abs
(
A
),
'(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size'
,
row_tiles
=
16
,
block_size
=
64
*
4
)
nnz_rows1_counts
=
(
torch
.
abs
(
A_blocked
)
>=
threshold
).
sum
(
3
).
flatten
()
nnz_block_ptr1
=
torch
.
zeros
(
nnz_rows1_counts
.
shape
[
0
]
+
1
,
dtype
=
nnz_rows1_counts
.
dtype
,
device
=
nnz_rows1_counts
.
device
)
nnz_block_ptr1
[
1
:]
=
nnz_rows1_counts
.
cumsum
(
0
)
torch
.
testing
.
assert_allclose
(
col_stats1_trunc
,
col_stats2
)
torch
.
testing
.
assert_allclose
(
row_stats1_trunc
,
row_stats2
)
torch
.
testing
.
assert_allclose
(
nnz_block_ptr1
,
nnz_block_ptr2
)
row_stats2
,
col_stats2
,
nnz_block_ptr2
=
F
.
get_colrow_absmax
(
A
,
threshold
=
0.0
)
torch
.
testing
.
assert_allclose
(
col_stats1
,
col_stats2
)
torch
.
testing
.
assert_allclose
(
row_stats1
,
row_stats2
)
assert
nnz_block_ptr2
is
None
n
=
2
#dim1 = [8*1024]
#dim2 = [4*1024]
dim1
=
torch
.
randint
(
1
,
4
*
1024
,
size
=
(
n
,)).
tolist
()
dim2
=
torch
.
randint
(
1
,
4
*
1024
,
size
=
(
n
,)).
tolist
()
values
=
list
(
product
(
dim1
,
dim2
))
names
=
[
'dim1_{0}_dim2_{1}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim2"
,
values
,
ids
=
names
)
def
test_double_quant
(
dim1
,
dim2
):
for
i
in
range
(
k
):
A
=
torch
.
randn
(
dim1
,
dim2
,
device
=
'cuda'
).
half
()
out_col1
,
Scol
=
F
.
vectorwise_quant
(
A
,
dim
=
0
)
out_row1
,
Srow
=
F
.
vectorwise_quant
(
A
,
dim
=
1
)
CA
,
CAt
,
statsA
,
statsAt
,
coo_tensor
=
F
.
double_quant
(
A
)
# max difference is 1 due to rounding differences
torch
.
testing
.
assert_allclose
(
CA
,
out_row1
,
atol
=
1
,
rtol
=
0
)
torch
.
testing
.
assert_allclose
(
CAt
,
out_col1
,
atol
=
1
,
rtol
=
0
)
n
=
CAt
.
numel
()
num_not_close_rows
=
(
torch
.
isclose
(
CA
,
out_row1
,
atol
=
1
)
==
0
).
sum
().
item
()
num_not_close_cols
=
(
torch
.
isclose
(
CAt
,
out_col1
,
atol
=
1
)
==
0
).
sum
().
item
()
# allow for 1:500 error due to rounding differences
min_error
=
1
/
500
if
num_not_close_cols
>
(
min_error
*
n
):
print
(
f
'Min error exceeded
{
num_not_close_cols
}
elements are different. Error:
{
num_not_close_cols
/
n
:.
4
f
}
'
)
assert
False
if
num_not_close_rows
>
(
min_error
*
n
):
print
(
f
'Min error exceeded
{
num_not_close_rows
}
elements are different. Error:
{
num_not_close_rows
/
n
:.
4
f
}
'
)
assert
False
torch
.
testing
.
assert_allclose
(
Srow
.
flatten
(),
statsA
)
torch
.
testing
.
assert_allclose
(
Scol
.
flatten
(),
statsAt
)
n
=
4
dim1
=
torch
.
randint
(
1
,
4
*
1024
,
size
=
(
n
,)).
tolist
()
dim4
=
torch
.
randint
(
1
,
4
*
1024
,
size
=
(
n
,)).
tolist
()
inner
=
torch
.
randint
(
1
,
4
*
1024
,
size
=
(
n
,)).
tolist
()
dim1
=
[
6
]
dim4
=
[
4
]
inner
=
[
8
]
values
=
list
(
zip
(
dim1
,
dim4
,
inner
))
names
=
[
'dim1_{0}_dim4_{1}_inner_{2}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim4, inner"
,
values
,
ids
=
names
)
def
test_integrated_igemmlt
(
dim1
,
dim4
,
inner
):
for
i
in
range
(
k
):
A
=
torch
.
randn
(
dim1
,
inner
,
device
=
'cuda'
).
half
()
B
=
torch
.
randn
(
dim4
,
inner
,
device
=
'cuda'
).
half
()
out1
=
torch
.
matmul
(
A
.
half
(),
B
.
t
().
half
())
C1a
,
C1b
,
stats1a
,
stats1b
,
coo_tensor
=
F
.
double_quant
(
A
)
C2a
,
C2b
,
stats2a
,
stats2b
,
coo_tensor
=
F
.
double_quant
(
B
)
A1
,
maxA
=
F
.
vectorwise_quant
(
A
,
dim
=
1
)
B1
,
maxB
=
F
.
vectorwise_quant
(
B
,
dim
=
1
)
torch
.
testing
.
assert_allclose
(
maxA
.
flatten
(),
stats1a
)
torch
.
testing
.
assert_allclose
(
maxB
.
flatten
(),
stats2a
)
torch
.
testing
.
assert_allclose
(
C1a
,
A1
,
rtol
=
0
,
atol
=
1
)
torch
.
testing
.
assert_allclose
(
C2a
,
B1
,
rtol
=
0
,
atol
=
1
)
A2
,
SA
=
F
.
nvidia_transform
(
C1a
,
'col32'
)
B2
,
SB
=
F
.
nvidia_transform
(
C2a
,
'col_turing'
)
outC32
,
SC
=
F
.
igemmlt
(
A2
,
B2
,
SA
,
SB
)
out2
=
F
.
mm_dequant
(
outC32
,
SC
,
stats1a
,
stats2a
)
A2
,
SA
=
F
.
nvidia_transform
(
A1
,
'col32'
)
B2
,
SB
=
F
.
nvidia_transform
(
B1
,
'col_turing'
)
C2
,
SC
=
F
.
igemmlt
(
A2
,
B2
,
SA
,
SB
)
C3
,
S
=
F
.
nvidia_transform
(
C2
,
'row'
,
state
=
SC
)
out3
=
F
.
vectorwise_mm_dequant
(
C3
.
float
(),
maxA
,
maxB
.
t
())
err1
=
torch
.
abs
(
out1
-
out2
).
mean
().
item
()
err2
=
torch
.
abs
(
out1
-
out3
).
mean
().
item
()
assert
err2
<=
err1
*
1.01
n
=
6
dim1
=
torch
.
randint
(
1
,
4
*
1024
,
size
=
(
n
,)).
tolist
()
dim4
=
torch
.
randint
(
1
,
4
*
1024
,
size
=
(
n
,)).
tolist
()
inner
=
torch
.
randint
(
1
,
4
*
1024
,
size
=
(
n
,)).
tolist
()
values
=
list
(
zip
(
dim1
,
dim4
,
inner
))
names
=
[
'dim1_{0}_dim4_{1}_inner_{2}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim4, inner"
,
values
,
ids
=
names
)
def
test_igemmlt_row_scale
(
dim1
,
dim4
,
inner
):
formatB
=
F
.
get_special_format_str
()
err1
,
err2
,
err3
=
[],
[],
[]
relerr1
,
relerr2
=
[],
[]
scale
=
1
for
i
in
range
(
k
):
A
=
torch
.
randn
(
dim1
,
inner
,
device
=
'cuda'
).
half
()
B
=
torch
.
randn
(
dim4
,
inner
,
device
=
'cuda'
).
half
()
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
C1
=
torch
.
matmul
(
A
,
B
.
t
())
out1
=
torch
.
matmul
(
A
.
half
(),
B
.
t
().
half
())
C1a
,
C1b
,
stats1a
,
stats1b
,
coo_tensor
=
F
.
double_quant
(
A
)
CB
,
absmaxB
=
F
.
vectorwise_quant
(
B
,
quant_type
=
'linear'
)
A2
,
SA
=
F
.
nvidia_transform
(
C1a
,
'col32'
)
B2
,
SB
=
F
.
nvidia_transform
(
CB
,
formatB
)
A1
,
maxA
=
F
.
vectorwise_quant
(
A
,
dim
=
1
)
c
=
10.0
*
inner
*
scale
row_scale
=
torch
.
ones_like
(
maxA
)
/
c
outC32
,
SC
=
F
.
igemmlt
(
A2
,
B2
,
SA
,
SB
,
dtype
=
torch
.
int8
,
row_scale
=
row_scale
)
C3
,
S
=
F
.
nvidia_transform
(
outC32
,
'row'
,
state
=
SC
)
maxval
=
torch
.
abs
(
C3
).
max
()
if
maxval
==
127
:
scale
=
1.5
else
:
scale
=
maxval
/
120
out3
=
C3
*
maxA
*
absmaxB
*
c
/
(
127
*
127
)
C4
=
torch
.
matmul
(
C1a
.
float
(),
CB
.
float
().
t
())
C2a
,
C2b
,
stats2a
,
stats2b
,
coo_tensor
=
F
.
double_quant
(
B
)
B2
,
SB
=
F
.
nvidia_transform
(
C2a
,
formatB
)
outC32
,
SC
=
F
.
igemmlt
(
A2
,
B2
,
SA
,
SB
)
out2
=
F
.
mm_dequant
(
outC32
,
SC
,
stats1a
,
stats2a
)
CA
,
SA
=
F
.
vectorwise_quant
(
A
,
dim
=
1
,
quant_type
=
'vector'
)
CB
,
SB
=
F
.
vectorwise_quant
(
B
,
dim
=
1
,
quant_type
=
'linear'
)
C
=
torch
.
matmul
(
CA
.
float
(),
CB
.
t
().
float
())
out4
=
C
*
SA
*
SB
/
(
127
*
127
)
#out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
#print('='*80)
#print(out1)
#print(out2)
#print(out3)
#print(out1)
#print(out2)
#print(out3)
err1
.
append
(
torch
.
abs
(
out1
-
out2
).
mean
().
item
())
err2
.
append
(
torch
.
abs
(
out1
-
out3
).
mean
().
item
())
err3
.
append
(
torch
.
abs
(
out1
-
out4
).
mean
().
item
())
#assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
print
(
''
)
print
(
sum
(
err1
)
/
len
(
err1
))
print
(
sum
(
err2
)
/
len
(
err2
))
print
(
sum
(
err3
)
/
len
(
err3
))
dim1
=
[
1024
,
2048
]
inner
=
[
12288
*
4
,
4096
*
4
]
dim4
=
[
12288
,
4096
]
values
=
list
(
zip
(
dim1
,
dim4
,
inner
))
names
=
[
'dim1_{0}_dim4_{1}_inner_{2}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim4, inner"
,
values
,
ids
=
names
)
def
test_row_scale_bench
(
dim1
,
dim4
,
inner
):
err1
,
err2
,
err3
=
[],
[],
[]
relerr1
,
relerr2
=
[],
[]
scale
=
1
A
=
torch
.
randn
(
dim1
,
inner
,
device
=
'cuda'
).
half
()
B
=
torch
.
randn
(
dim4
,
inner
,
device
=
'cuda'
).
half
()
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
# warmpup
for
i
in
range
(
k
):
C1
=
torch
.
matmul
(
A
,
B
.
t
())
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
k
):
C1
=
torch
.
matmul
(
A
,
B
.
t
())
torch
.
cuda
.
synchronize
()
print
(
'16'
,
time
.
time
()
-
t0
)
C1a
,
C1b
,
stats1a
,
stats1b
,
coo_tensor
=
F
.
double_quant
(
A
)
CB
,
absmaxB
=
F
.
vectorwise_quant
(
B
,
quant_type
=
'linear'
)
A2
,
SA
=
F
.
nvidia_transform
(
C1a
,
'col32'
)
B2
,
SB
=
F
.
nvidia_transform
(
CB
,
formatB
)
A1
,
maxA
=
F
.
vectorwise_quant
(
A
,
dim
=
1
)
c
=
10.0
*
inner
*
scale
row_scale
=
maxA
/
c
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
k
):
outC32
,
SC
=
F
.
igemmlt
(
A2
,
B2
,
SA
,
SB
,
dtype
=
torch
.
int8
,
row_scale
=
row_scale
)
torch
.
cuda
.
synchronize
()
print
(
'row-wise'
,
time
.
time
()
-
t0
)
C2a
,
C2b
,
stats2a
,
stats2b
,
coo_tensor
=
F
.
double_quant
(
B
)
B2
,
SB
=
F
.
nvidia_transform
(
C2a
,
formatB
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
k
):
outC32
,
SC
=
F
.
igemmlt
(
A2
,
B2
,
SA
,
SB
)
torch
.
cuda
.
synchronize
()
print
(
'vector-wise'
,
time
.
time
()
-
t0
)
n
=
2
dim1
=
torch
.
randint
(
2
,
1024
,
size
=
(
n
,)).
tolist
()
dim2
=
torch
.
randint
(
2
,
1024
,
size
=
(
n
,)).
tolist
()
#dim1 = [8*1024]
#dim2 = [4*1024]
dim3
=
[
0
]
dtype
=
[
torch
.
int8
]
a_order
=
[
'row'
]
out_order
=
[
'col32'
,
'col_turing'
,
'col_ampere'
]
transpose
=
[
False
,
True
]
dims
=
[
2
]
values
=
list
(
product
(
dim1
,
dim2
,
dim3
,
dims
,
dtype
,
a_order
,
out_order
,
transpose
))
names
=
[
'dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose"
,
values
,
ids
=
names
)
def
test_transform
(
dim1
,
dim2
,
dim3
,
dims
,
dtype
,
orderA
,
orderOut
,
transpose
):
for
i
in
range
(
k
):
if
dims
==
2
:
A
=
torch
.
randint
(
10
,
99
,
size
=
(
dim1
,
dim2
),
device
=
'cuda'
).
to
(
dtype
)
elif
dims
==
3
:
A
=
torch
.
randint
(
10
,
99
,
size
=
(
dim1
,
dim2
,
dim3
),
device
=
'cuda'
).
to
(
dtype
)
A
.
view
(
-
1
)[
-
1
]
=
-
1
if
transpose
:
At
=
A
.
t
().
contiguous
()
out1
,
S1
=
F
.
nvidia_transform
(
At
,
to_order
=
orderOut
)
else
:
out1
,
S1
=
F
.
nvidia_transform
(
A
,
to_order
=
orderOut
)
out2
,
S2
=
F
.
transform
(
A
,
to_order
=
orderOut
,
transpose
=
transpose
)
assert
S1
[
0
][
0
]
==
S2
[
0
][
0
]
assert
S1
[
0
][
1
]
==
S2
[
0
][
1
]
#print(out1)
#print(out2)
torch
.
testing
.
assert_allclose
(
out1
,
out2
)
n
=
2
#dim1 = torch.randint(2,1024, size=(n,)).tolist()
#dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1
=
[
1
]
dim2
=
[
33
]
dtype
=
[
torch
.
int8
]
#a_order = ['col_turing', 'col_ampere']
a_order
=
[
'col_turing'
]
out_order
=
[
'row'
]
values
=
list
(
product
(
dim1
,
dim2
,
dtype
,
a_order
,
out_order
))
names
=
[
'dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, dtype, orderA, orderOut"
,
values
,
ids
=
names
)
def
test_transform_to_row
(
dim1
,
dim2
,
dtype
,
orderA
,
orderOut
):
for
i
in
range
(
1
):
A
=
torch
.
randint
(
-
127
,
127
,
size
=
(
dim1
,
dim2
),
device
=
'cuda'
).
to
(
dtype
)
out2
,
S2
=
F
.
transform
(
A
,
to_order
=
orderA
)
A2
,
S3
=
F
.
transform
(
out2
,
from_order
=
orderA
,
to_order
=
'row'
,
state
=
S2
)
assert
A2
.
shape
[
0
]
==
A
.
shape
[
0
]
assert
A2
.
shape
[
1
]
==
A
.
shape
[
1
]
print
(
''
)
print
(
A
)
print
(
out2
)
print
(
A2
)
#torch.testing.assert_allclose(A, A2)
def
test_overflow
():
formatB
=
F
.
get_special_format_str
()
for
i
in
range
(
2
):
a
=
torch
.
arange
(
5
,
15
).
cuda
().
to
(
torch
.
int8
).
view
(
-
1
,
1
)
b
=
torch
.
arange
(
5
,
15
).
cuda
().
to
(
torch
.
int8
).
view
(
-
1
,
1
)
Ca
,
Sa
=
F
.
nvidia_transform
(
a
,
'col32'
)
Cb
,
Sb
=
F
.
nvidia_transform
(
b
,
formatB
)
c
=
F
.
igemmlt
(
Ca
,
Cb
,
Sa
,
Sb
,
dtype
=
torch
.
int8
)
c2
=
torch
.
matmul
(
a
.
float
(),
b
.
float
().
t
())
n
=
2
dim1
=
torch
.
randint
(
1
,
4
*
1024
,
size
=
(
n
,)).
tolist
()
dim2
=
torch
.
randint
(
1
,
4
*
1024
,
size
=
(
n
,)).
tolist
()
#dim1 = [4]
#dim2 = [5]
values
=
list
(
product
(
dim1
,
dim2
))
names
=
[
'dim1_{0}_dim2_{1}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim2"
,
values
,
ids
=
names
)
def
test_coo_double_quant
(
dim1
,
dim2
):
threshold
=
3.00
for
i
in
range
(
k
):
A
=
torch
.
randn
(
dim1
,
dim2
,
device
=
'cuda'
).
half
()
idx
=
(
torch
.
abs
(
A
)
>=
threshold
)
CA2
,
CAt
,
statsA
,
statsAt
,
coo_tensor
=
F
.
double_quant
(
A
)
CA
,
CAt
,
statsA
,
statsAt
,
coo_tensor
=
F
.
double_quant
(
A
,
threshold
=
threshold
)
if
coo_tensor
is
not
None
:
A1
=
A
*
idx
A2
=
torch
.
zeros_like
(
A
)
A2
[
coo_tensor
.
rowidx
.
long
(),
coo_tensor
.
colidx
.
long
()]
=
coo_tensor
.
values
torch
.
testing
.
assert_allclose
(
A1
,
A2
)
A1
=
A
*
(
idx
==
0
)
A2
=
(
CA
.
float
()
*
statsA
.
unsqueeze
(
1
)
/
127
).
half
()
torch
.
testing
.
assert_allclose
(
A
*
(
idx
==
0
),
A2
,
rtol
=
0.05
,
atol
=
1.5e-2
)
n
=
2
dim1
=
torch
.
randint
(
1
,
1
*
1024
,
size
=
(
n
,)).
tolist
()
dim2
=
torch
.
randint
(
1
,
1
*
1024
,
size
=
(
n
,)).
tolist
()
#dim1 = [7]
#dim2 = [11]
transposed_B
=
[
False
,
True
]
values
=
list
(
product
(
dim1
,
dim2
,
transposed_B
))
names
=
[
'dim1_{0}_dim2_{1}_transposed_B_{2}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, transposed_B"
,
values
,
ids
=
names
)
def
test_spmm_coo
(
dim1
,
dim2
,
transposed_B
):
threshold
=
1.5
dim3
=
torch
.
randint
(
32
,
128
,
size
=
(
1
,)).
item
()
#dim3 = 17
for
i
in
range
(
k
):
A
=
torch
.
randn
(
dim1
,
dim2
).
cuda
().
half
()
if
transposed_B
:
B
=
torch
.
randn
(
dim3
,
dim2
).
cuda
().
half
()
else
:
B
=
torch
.
randn
(
dim2
,
dim3
).
cuda
().
half
()
idx
=
torch
.
abs
(
A
)
>=
threshold
nnz
=
(
idx
==
1
).
sum
().
item
()
rows
,
cols
=
torch
.
where
(
idx
)
values
=
A
[
idx
]
cooA
=
F
.
COOSparseTensor
(
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A2
=
A
*
idx
if
transposed_B
:
out2
=
F
.
spmm_coo
(
cooA
,
B
.
t
())
out1
=
torch
.
matmul
(
A2
,
B
.
t
())
else
:
out2
=
F
.
spmm_coo
(
cooA
,
B
)
out1
=
torch
.
matmul
(
A2
,
B
)
assert_all_approx_close
(
out1
,
out2
,
rtol
=
0.01
,
atol
=
3.0e-2
,
count
=
30
)
def
test_spmm_bench
():
batch
=
2
model
=
1024
*
1
hidden
=
model
*
4
seq
=
1024
dim1
=
batch
*
seq
dim2
=
model
dim3
=
hidden
threshold
=
4
A
=
torch
.
randn
(
dim1
,
dim2
,
device
=
'cuda'
).
half
()
B
=
torch
.
randn
(
dim2
,
dim3
,
device
=
'cuda'
).
half
()
for
i
in
range
(
10
):
for
i
in
range
(
10
):
A1
=
torch
.
randn
(
1024
,
1024
,
device
=
'cpu'
)
C1
=
bnb
.
matmul
(
A
,
B
)
C
,
S
=
F
.
quantize_blockwise
(
A1
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
)
torch
.
cuda
.
synchronize
()
diff
=
torch
.
abs
(
A1
-
A2
)
t0
=
time
.
time
()
reldiff
=
diff
/
torch
.
abs
(
A1
+
1e-8
)
for
i
in
range
(
k
):
diffs
.
append
(
diff
.
mean
().
item
())
C1
=
bnb
.
matmul
(
A
,
B
)
reldiffs
.
append
(
reldiff
.
mean
().
item
())
torch
.
cuda
.
synchronize
()
assert
diffs
[
-
1
]
<
0.011
t8
=
time
.
time
()
-
t0
#print(sum(diffs)/len(diffs))
#print(sum(reldiffs)/len(reldiffs))
idx
=
torch
.
abs
(
A
)
>=
threshold
nnz
=
(
idx
==
1
).
sum
().
item
()
print
(
nnz
/
idx
.
numel
())
rows
,
cols
=
torch
.
where
(
idx
)
values
=
A
[
idx
]
cooA
=
F
.
COOSparseTensor
(
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
diffs
=
[]
for
i
in
range
(
10
):
for
i
in
range
(
10
):
A1
=
torch
.
rand
(
1024
,
1024
,
device
=
'cpu'
)
out2
=
F
.
spmm_coo
(
cooA
,
B
)
C
,
S
=
F
.
quantize_blockwise
(
A1
)
A2
=
F
.
dequantize_blockwise
(
C
,
S
)
torch
.
cuda
.
synchronize
()
diff
=
torch
.
abs
(
A1
-
A2
).
mean
().
item
()
t0
=
time
.
time
()
assert
diff
<
0.0033
for
i
in
range
(
k
):
diffs
.
append
(
diff
)
out2
=
F
.
spmm_coo
(
cooA
,
B
)
torch
.
testing
.
assert_allclose
(
A1
,
A2
,
atol
=
1e-2
,
rtol
=
0
)
torch
.
cuda
.
synchronize
()
#print(sum(diffs)/len(diffs))
tsp
=
time
.
time
()
-
t0
print
(
tsp
,
t8
)
print
(
tsp
/
t8
)
n
=
2
dim1
=
torch
.
randint
(
256
,
1
*
1024
,
size
=
(
n
,)).
tolist
()
dim2
=
torch
.
randint
(
256
,
1
*
1024
,
size
=
(
n
,)).
tolist
()
values
=
list
(
product
(
dim1
,
dim2
))
names
=
[
'dim1_{0}_dim2_{1}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim2"
,
values
,
ids
=
names
)
def
test_integrated_sparse_decomp
(
dim1
,
dim2
):
threshold
=
3.0
formatB
=
'col_turing'
for
i
in
range
(
k
):
A
=
torch
.
randn
(
dim1
,
dim2
).
cuda
().
half
()
w1
=
torch
.
randn
(
dim1
,
dim2
).
cuda
().
half
()
out1
=
torch
.
matmul
(
A
,
w1
.
t
())
Cw1
,
Cw1t
,
statsw1
,
statsw1t
,
coo_tensor
=
F
.
double_quant
(
w1
)
CTw1
,
Sw1
=
F
.
transform
(
Cw1
,
formatB
)
CA
,
CAt
,
statsA
,
statsAt
,
coo_tensor
=
F
.
double_quant
(
A
)
C32A
,
SA
=
F
.
transform
(
CA
,
'col32'
)
out1_32
,
Sout1_32
=
F
.
igemmlt
(
C32A
,
CTw1
,
SA
,
Sw1
)
out2
=
F
.
mm_dequant
(
out1_32
,
Sout1_32
,
statsA
,
statsw1
)
CA
,
CAt
,
statsA
,
statsAt
,
coo_tensor
=
F
.
double_quant
(
A
,
threshold
=
threshold
)
C32A
,
SA
=
F
.
transform
(
CA
,
'col32'
)
out1_32
,
Sout1_32
=
F
.
igemmlt
(
C32A
,
CTw1
,
SA
,
Sw1
)
out3
=
F
.
mm_dequant
(
out1_32
,
Sout1_32
,
statsA
,
statsw1
)
assert
coo_tensor
is
not
None
out4
=
F
.
spmm_coo
(
coo_tensor
,
w1
.
t
())
out5
=
out3
+
out4
err1
=
torch
.
abs
(
out1
-
out2
).
mean
().
item
()
err2
=
torch
.
abs
(
out1
-
out5
).
mean
().
item
()
assert
err2
<
err1
def
test_matmuls
():
a
=
torch
.
randn
(
256
,
256
).
half
().
cuda
()
b
=
torch
.
randn
(
256
,
256
).
half
().
cuda
()
c1
=
torch
.
matmul
(
a
,
b
)
c2
=
bnb
.
matmul
(
a
,
b
)
c3
=
bnb
.
matmul
(
a
,
b
)
err1
=
torch
.
abs
(
c1
-
c2
).
mean
().
item
()
err2
=
torch
.
abs
(
c1
-
c3
).
mean
().
item
()
assert
err1
<
0.2
assert
err2
<
0.2
n
=
2
#dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
#dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1
=
[
1
*
2048
]
dim2
=
[
12288
]
#dim1 = [32]
#dim2 = [32]
#dtype = [torch.float16, torch.int8]
dtype
=
[
torch
.
float16
]
out_function
=
[
'zeros'
,
'ones'
]
values
=
list
(
product
(
dim1
,
dim2
,
dtype
,
out_function
))
names
=
[
'dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, dtype, out_func"
,
values
,
ids
=
names
)
def
test_spmm_coo_very_sparse
(
dim1
,
dim2
,
dtype
,
out_func
):
out_func
=
getattr
(
torch
,
out_func
)
threshold
=
3.3
#threshold = 2.8
#threshold = 0.0
A
=
torch
.
randn
(
dim1
,
dim2
,
device
=
'cuda'
).
half
()
if
dtype
==
torch
.
float16
:
B
=
torch
.
randn
(
dim2
,
dim2
*
4
,
device
=
'cuda'
).
half
()
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
else
:
B
=
torch
.
randn
(
dim2
,
dim2
*
4
,
device
=
'cuda'
).
half
()
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
B
,
SB
=
F
.
vectorwise_quant
(
B
,
quant_type
=
'linear'
)
#B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
print
(
''
)
idx
=
torch
.
abs
(
A
)
>=
threshold
nnz
=
(
idx
==
1
).
sum
().
item
()
rows
,
cols
=
torch
.
where
(
idx
)
values
=
A
[
idx
]
cooA
=
F
.
COOSparseTensor
(
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A2
=
A
*
idx
out1
=
torch
.
matmul
(
A2
.
half
(),
B
.
half
())
out
=
out_func
(
out1
.
shape
,
dtype
=
torch
.
float16
,
device
=
out1
.
device
)
out1
+=
out
.
clone
()
out2
=
F
.
spmm_coo_very_sparse
(
cooA
,
B
,
out
=
out
)
#print(B)
#print(out1)
#print(out2)
p
=
200
/
(
2048
*
12288
*
4
)
n
=
out1
.
numel
()
count
=
math
.
ceil
(
p
*
n
)
std
=
out1
.
std
()
out1
/=
std
out2
/=
std
assert_all_approx_close
(
out1
,
out2
.
half
(),
rtol
=
0.01
,
atol
=
3.0e-2
,
count
=
count
)
#assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
idx_col
=
torch
.
randint
(
0
,
A2
.
shape
[
-
1
],
size
=
(
15
,))
#torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)
#Bt = torch.randn(dim2*4, dim2, device='cuda').half()
#torch.cuda.synchronize()
#t0 = time.time()
#print(A2.shape, B.shape)
#for i in range(100):
# #out3 = F.spmm_coo(cooA, Bt.t())
# #out2 = F.spmm_coo(cooA, B)
# #out2 = F.spmm_coo_very_sparse(cooA, B)
# #out1 = torch.matmul(A, Bt.t())
#torch.cuda.synchronize()
#print(time.time() - t0)
def
test_layout
():
a1
=
torch
.
rand
(
16
,
64
,
device
=
'cuda'
,
dtype
=
torch
.
float16
)
a1
=
torch
.
arange
(
16
*
64
,
device
=
'cuda'
).
reshape
(
16
,
64
).
byte
()
a2
,
s2
=
F
.
transform
(
a1
,
'col_turing'
)
print
(
a2
.
shape
)
print
(
a1
.
flatten
()[
8
*
64
:
8
*
64
+
32
])
for
i
in
range
(
4
):
print
(
a2
.
flatten
()[
i
*
8
*
32
:
i
*
8
*
32
+
32
],
0
)
def
test_coo2csr
():
threshold
=
1
A
=
torch
.
randn
(
128
,
128
).
half
().
cuda
()
idx
=
torch
.
abs
(
A
)
>=
threshold
nnz
=
(
idx
==
1
).
sum
().
item
()
rows
,
cols
=
torch
.
where
(
idx
)
values
=
A
[
idx
]
cooA
=
F
.
COOSparseTensor
(
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A2
=
A
*
idx
csrA
=
F
.
coo2csr
(
cooA
)
counts
=
csrA
.
rowptr
[
1
:]
-
csrA
.
rowptr
[:
-
1
]
assert
counts
.
numel
()
==
A
.
shape
[
0
]
torch
.
testing
.
assert_allclose
(
counts
,
(
A2
!=
0
).
sum
(
1
))
idx
=
(
A2
!=
0
)
torch
.
testing
.
assert_allclose
(
A2
[
idx
],
csrA
.
values
)
def
test_coo2csc
():
threshold
=
1
A
=
torch
.
randn
(
128
,
128
).
half
().
cuda
()
idx
=
torch
.
abs
(
A
)
>=
threshold
nnz
=
(
idx
==
1
).
sum
().
item
()
rows
,
cols
=
torch
.
where
(
idx
)
values
=
A
[
idx
]
cooA
=
F
.
COOSparseTensor
(
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A2
=
A
*
idx
cscA
=
F
.
coo2csc
(
cooA
)
counts
=
cscA
.
colptr
[
1
:]
-
cscA
.
colptr
[:
-
1
]
assert
counts
.
numel
()
==
A
.
shape
[
1
]
torch
.
testing
.
assert_allclose
(
counts
,
(
A2
!=
0
).
sum
(
0
))
# torch uses row-major -> use transpose to transfer to col-major
idx
=
(
A2
.
t
()
!=
0
)
torch
.
testing
.
assert_allclose
(
A2
.
t
()[
idx
],
cscA
.
values
)
n
=
2
#dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
#dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dim1
=
[
1
*
2048
]
#dim2 = [12288]
dim2
=
[
2048
]
#dim1 = [2]
#dim2 = [2]
dtype
=
[
torch
.
int8
]
values
=
list
(
product
(
dim1
,
dim2
,
dtype
))
names
=
[
'dim1_{0}_dim2_{1}_dtype_{2}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim2, dtype"
,
values
,
ids
=
names
)
def
test_spmm_coo_dequant
(
dim1
,
dim2
,
dtype
):
threshold
=
6.0
#threshold = 2.8
#threshold = 0.0
A
=
torch
.
randn
(
dim1
,
dim2
,
device
=
'cuda'
).
half
()
B
=
torch
.
empty
(
dim2
,
dim2
*
4
,
device
=
'cuda'
,
dtype
=
torch
.
float16
)
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
Bt
=
B
.
t
().
contiguous
()
CB
,
CBt
,
statsB
,
statsBt
,
coo_tensor
=
F
.
double_quant
(
B
)
rowidx
=
torch
.
randint
(
0
,
A
.
shape
[
-
1
],
size
=
(
15
,))
A
[:,
rowidx
]
=
8.0
idx
=
torch
.
abs
(
A
)
>=
threshold
nnz
=
(
idx
==
1
).
sum
().
item
()
rows
,
cols
=
torch
.
where
(
idx
)
values
=
A
[
idx
]
cooA
=
F
.
COOSparseTensor
(
A
.
shape
[
0
],
A
.
shape
[
1
],
nnz
,
rows
.
int
(),
cols
.
int
(),
values
)
A2
=
A
*
idx
out2
=
F
.
spmm_coo_very_sparse
(
cooA
,
CBt
,
dequant_stats
=
statsBt
)
out1
=
torch
.
matmul
(
A2
,
B
.
half
())
out3
=
F
.
spmm_coo_very_sparse
(
cooA
,
CBt
.
half
())
out3
=
out3
*
statsBt
.
half
()
/
127
values
,
counts
=
torch
.
unique
(
cooA
.
rowidx
,
return_counts
=
True
)
offset
=
counts
.
cumsum
(
0
).
int
()
max_count
,
max_idx
=
torch
.
sort
(
counts
,
descending
=
True
)
print
(
torch
.
median
(
max_count
.
float
()))
torch
.
testing
.
assert_allclose
(
out2
,
out3
,
rtol
=
0.05
,
atol
=
0.001
)
p
=
200
/
(
2048
*
12288
*
4
)
n
=
out1
.
numel
()
count
=
math
.
ceil
(
p
*
n
)
assert_all_approx_close
(
out1
,
out2
,
rtol
=
0.01
,
atol
=
3.0e-2
,
count
=
count
)
#torch.cuda.synchronize()
#t0 = time.time()
#for i in range(100):
# out2 = F.spmm_coo_very_sparse(cooA, B)
#torch.cuda.synchronize()
#print('fp16', time.time() - t0)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
out2
=
F
.
spmm_coo
(
cooA
,
B
)
torch
.
cuda
.
synchronize
()
print
(
'cusparse fp16'
,
time
.
time
()
-
t0
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
out2
=
F
.
spmm_coo_very_sparse
(
cooA
,
CBt
)
torch
.
cuda
.
synchronize
()
print
(
'int8'
,
time
.
time
()
-
t0
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
out2
=
F
.
spmm_coo_very_sparse
(
cooA
,
CBt
,
dequant_stats
=
statsBt
)
torch
.
cuda
.
synchronize
()
print
(
'int8+dequant'
,
time
.
time
()
-
t0
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
out2
=
torch
.
matmul
(
A
,
B
)
torch
.
cuda
.
synchronize
()
print
(
'matmul'
,
time
.
time
()
-
t0
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
out1
=
bnb
.
matmul
(
A
,
Bt
)
out2
=
F
.
spmm_coo_very_sparse
(
cooA
,
CBt
,
dequant_stats
=
statsBt
)
out
=
out1
+
out2
torch
.
cuda
.
synchronize
()
print
(
'sparse+ matmul'
,
time
.
time
()
-
t0
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
out1
=
bnb
.
matmul
(
A
,
Bt
)
torch
.
matmul
(
A
[:,
rowidx
],
Bt
.
t
()[
rowidx
],
out
=
out1
)
torch
.
cuda
.
synchronize
()
print
(
'partial matmul'
,
time
.
time
()
-
t0
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
out1
=
bnb
.
matmul
(
A
,
Bt
)
torch
.
cuda
.
synchronize
()
print
(
'partial matmul'
,
time
.
time
()
-
t0
)
batch_size
=
1
seqdim
=
2048
values
=
[]
values
.
append
((
batch_size
,
seqdim
,
768
,
4
*
768
))
#values.append((batch_size, seqdim, 1024, 4*1024))
#values.append((batch_size, seqdim, 1536, 4*1536))
#values.append((batch_size, seqdim, 2048, 4*2048))
#values.append((batch_size, seqdim, 2560, 4*2560))
#values.append((batch_size, seqdim, 4096, 4*4096))
#values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288))
names
=
[
'batch_{0}_seq_{1}_model_{2}_hidden_{3}'
.
format
(
*
vals
)
for
vals
in
values
]
@
pytest
.
mark
.
parametrize
(
"batch, seq, model, hidden"
,
values
,
ids
=
names
)
def
test_bench_matmul
(
batch
,
seq
,
model
,
hidden
):
formatB
=
F
.
get_special_format_str
()
A
=
torch
.
randn
(
batch
,
seq
,
model
,
device
=
'cuda'
).
half
()
B
=
torch
.
empty
(
hidden
,
model
,
dtype
=
torch
.
float16
,
device
=
'cuda'
)
torch
.
nn
.
init
.
xavier_uniform_
(
B
)
linear8bit
=
bnb
.
nn
.
Linear8bitLt
(
model
,
hidden
,
False
).
cuda
().
half
()
linear8bit
.
eval
()
outliers
=
torch
.
randint
(
0
,
model
,
size
=
(
5
,)).
cuda
()
A
[:,
:,
outliers
]
=
8.0
linearMixedBit
=
bnb
.
nn
.
Linear8bitLt
(
model
,
hidden
,
False
,
threshold
=
6.0
).
cuda
().
half
()
linearMixedBit
.
eval
()
# warmup
for
i
in
range
(
100
):
torch
.
matmul
(
A
,
B
.
t
())
torch
.
cuda
.
synchronize
()
print
(
''
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
torch
.
matmul
(
A
,
B
.
t
())
torch
.
cuda
.
synchronize
()
print
(
f
'pytorch: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s'
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
bnb
.
matmul
(
A
,
B
)
torch
.
cuda
.
synchronize
()
print
(
f
'bnb lt: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s'
)
CA
,
CAt
,
SCA
,
SCAt
,
coo_tensorA
=
F
.
double_quant
(
A
,
threshold
=
0.0
)
C32A
,
SA
=
F
.
transform
(
CA
,
'col32'
)
CB
,
CBt
,
SCB
,
SCBt
,
coo_tensorB
=
F
.
double_quant
(
B
)
CxB
,
SB
=
F
.
transform
(
CB
,
to_order
=
formatB
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
out32
,
Sout32
=
F
.
igemmlt
(
C32A
,
CxB
,
SA
,
SB
)
torch
.
cuda
.
synchronize
()
print
(
f
'igemmlt: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s'
)
BA
,
statsB
=
F
.
vectorwise_quant
(
B
,
dim
=
1
)
CxB
,
SB
=
F
.
nvidia_transform
(
CB
,
to_order
=
formatB
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
A2
=
A
.
view
(
-
1
,
A
.
shape
[
-
1
]).
contiguous
()
CA
,
statsA
=
F
.
vectorwise_quant
(
A2
,
dim
=
1
)
C32A
,
SA
=
F
.
nvidia_transform
(
CA
,
'col32'
)
out32
,
Sout32
=
F
.
igemmlt
(
C32A
,
CxB
,
SA
,
SB
)
Cout
,
Sout
=
F
.
nvidia_transform
(
out32
,
'row'
,
state
=
Sout32
)
F
.
vectorwise_mm_dequant
(
Cout
,
statsA
,
statsB
.
t
())
torch
.
cuda
.
synchronize
()
print
(
f
'vector pytorch + nvidia: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s'
)
BA
,
statsB
=
F
.
vectorwise_quant
(
B
,
dim
=
1
,
quant_type
=
'linear'
)
CxB
,
SB
=
F
.
nvidia_transform
(
CB
,
to_order
=
formatB
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
A2
=
A
.
view
(
-
1
,
A
.
shape
[
-
1
]).
contiguous
()
CA
,
statsA
=
F
.
vectorwise_quant
(
A2
,
dim
=
1
,
quant_type
=
'linear'
)
C32A
,
SA
=
F
.
nvidia_transform
(
CA
,
'col32'
)
out32
,
Sout32
=
F
.
igemmlt
(
C32A
,
CxB
,
SA
,
SB
)
Cout
,
Sout
=
F
.
nvidia_transform
(
out32
,
'row'
,
state
=
Sout32
)
out
=
Cout
*
statsB
*
statsA
*
(
1.0
/
(
127
*
127
))
torch
.
cuda
.
synchronize
()
print
(
f
'linear pytorch + nvidia: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s'
)
linear8bit
(
A
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
linear8bit
(
A
)
torch
.
cuda
.
synchronize
()
print
(
f
'bnb linear8bitlt: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s'
)
linearMixedBit
(
A
)
torch
.
cuda
.
synchronize
()
t0
=
time
.
time
()
for
i
in
range
(
100
):
linearMixedBit
(
A
)
torch
.
cuda
.
synchronize
()
print
(
f
'bnb linear8bitlt with threshold: [
{
batch
}
,
{
seq
}
,
{
model
}
], [
{
model
}
,
{
hidden
}
]->[
{
batch
}
,
{
seq
}
,
{
hidden
}
]:
{
time
.
time
()
-
t0
:.
4
f
}
s'
)
def
test_zeropoint
():
def
min_max
(
x
):
maxA
=
torch
.
amax
(
x
,
dim
=
1
,
keepdim
=
True
)
minA
=
torch
.
amin
(
x
,
dim
=
1
,
keepdim
=
True
)
midpoint
=
(
maxA
-
minA
)
/
2.0
dyna
=
252
/
(
maxA
-
minA
)
#dyna *= 0.98
x
=
dyna
*
x
x
=
x
-
torch
.
round
((
dyna
*
(
minA
+
midpoint
)))
return
x
.
to
(
torch
.
int8
),
minA
,
midpoint
,
dyna
batch
=
2
seq
=
2
model
=
4
hidden
=
2
*
model
#batch = 4
#seq = 2048
#model = 1024
#hidden = 8*model
A
=
torch
.
randn
(
batch
*
seq
,
model
,
device
=
'cuda'
).
half
()
-
0.4
B
=
torch
.
nn
.
Parameter
(
torch
.
randn
(
model
,
hidden
,
device
=
'cuda'
).
half
())
#A[0] = 0
#B[:, 0] = 0
#A = A*(A>0)
#A[0, 0] = 0
#A[0, 0] = 6.0
Ac
,
minA
,
midpoint
,
dyna
=
min_max
(
A
)
#print(Ac[0, 0], 'zero')
#print(Ac, Ac.min(), Ac.max())
Bc
,
maxB
=
F
.
vectorwise_quant
(
B
,
quant_type
=
'linear'
)
out
=
F
.
igemm
(
Ac
,
Bc
)
out2
=
torch
.
matmul
(
A
,
B
)
offset
=
B
.
sum
(
0
)
*
torch
.
round
(
dyna
*
(
minA
+
midpoint
))
/
dyna
out
=
out
.
float
()
#print(out.shape, maxB.shape, scale.shape, offset.shape)
norm1
=
maxB
/
127
C4
=
(
out
/
dyna
)
*
norm1
+
offset
B1
=
torch
.
nn
.
Parameter
(
B
.
clone
())
B2
=
torch
.
nn
.
Parameter
(
B
.
clone
())
B3
=
torch
.
nn
.
Parameter
(
B
.
clone
())
B4
=
torch
.
nn
.
Parameter
(
B
.
clone
())
C1
=
torch
.
matmul
(
A
,
B1
)
C2
=
bnb
.
matmul_cublas
(
A
,
B2
,
None
,
'linear'
)
C3
=
bnb
.
matmul_cublas
(
A
,
B3
,
None
,
'zeropoint'
)
C4
=
bnb
.
matmul_cublas
(
A
,
B4
,
None
,
'vector-zeropoint'
)
err1
=
torch
.
abs
(
C1
-
C2
).
mean
().
item
()
err2
=
torch
.
abs
(
C1
-
C3
).
mean
().
item
()
err3
=
torch
.
abs
(
C1
-
C4
).
mean
().
item
()
print
(
err1
,
err2
,
err3
)
#assert err1 > err2
loss1
=
C1
.
mean
()
loss2
=
C2
.
mean
()
loss3
=
C3
.
mean
()
loss4
=
C4
.
mean
()
loss1
.
backward
()
loss2
.
backward
()
loss3
.
backward
()
loss4
.
backward
()
print
(
B
.
grad
)
print
(
B1
.
grad
)
print
(
B2
.
grad
)
print
(
B3
.
grad
)
print
(
B4
.
grad
)
err1
=
torch
.
abs
(
B1
.
grad
-
B2
.
grad
).
mean
().
item
()
err2
=
torch
.
abs
(
B1
.
grad
-
B3
.
grad
).
mean
().
item
()
err3
=
torch
.
abs
(
B1
.
grad
-
B4
.
grad
).
mean
().
item
()
print
(
err1
,
err2
,
err3
)
def
test_zp
():
def
quant_zp
(
x
):
dtype
=
x
.
dtype
x
=
x
.
float
()
dyna
=
x
.
max
()
-
x
.
min
()
if
dyna
==
0
:
dyna
=
1
qx
=
254.
/
dyna
minx
=
x
.
min
()
#zpx = torch.round(minx* qx)
#zpx = 127 - torch.round(x.max()* qx)
zpx
=
torch
.
round
(
x
.
min
()
*
qx
)
-
127
x
=
(
qx
*
x
)
+
zpx
return
x
,
qx
,
zpx
batch
=
2
seq
=
512
model
=
1024
hidden
=
4
*
model
A
=
torch
.
randn
(
batch
*
seq
,
model
,
device
=
'cuda'
).
half
()
*
0.1
B
=
torch
.
randn
(
model
,
hidden
,
device
=
'cuda'
).
half
()
*
0.1
C0
=
torch
.
matmul
(
A
,
B
)
#A, SA = F.vectorwise_quant(A, quant_type='linear')
#B, SB = F.vectorwise_quant(B, quant_type='linear')
A
=
A
.
float
()
B
=
B
.
float
()
C1
=
torch
.
matmul
(
A
,
B
)
C3
=
bnb
.
matmul
(
A
.
half
(),
B
.
t
().
contiguous
().
half
())
zp
=
1
#C2 = torch.matmul(A-zp, B)
#C2 += B.sum(0).view(1, -1)*zp
C2
=
torch
.
matmul
(
A
,
B
-
zp
)
C2
-=
A
.
sum
(
1
).
view
(
-
1
,
1
)
*
zp
ca
,
cqa
,
cza
=
quant_zp
(
A
)
print
(
ca
.
min
(),
ca
.
max
())
print
((
ca
-
cza
).
min
(),
(
ca
-
cza
).
max
())
zp
=
1
scale
=
2.0
C5
=
torch
.
matmul
((
A
*
scale
)
-
zp
,
B
)
C5
+=
B
.
sum
(
0
)
*
zp
C5
/=
scale
CA
,
qa
,
zpa
=
quant_zp
(
A
)
C4
=
torch
.
matmul
(
CA
,
B
)
C4
-=
B
.
sum
(
0
)
*
zpa
C4
/=
qa
zpb
=
1
zpa
=
1
qa
=
2
qb
=
2
C6
=
torch
.
matmul
((
A
*
qa
)
+
zpa
,
(
B
*
qb
)
+
zpb
)
C6
-=
(
qb
*
B
.
sum
(
0
).
view
(
1
,
-
1
)
*
zpa
)
+
(
qa
*
A
.
sum
(
1
).
view
(
-
1
,
1
)
*
zpb
)
C6
-=
zpa
*
zpb
*
A
.
shape
[
1
]
C6
/=
qa
*
qb
def
test_histogram
():
CA
,
qa
,
zpa
=
quant_zp
(
A
)
dim1
,
dim2
=
32
,
32
CB
,
qb
,
zpb
=
quant_zp
(
B
)
source
=
torch
.
rand
(
dim1
,
dim2
,
device
=
'cuda'
)
C7
=
torch
.
matmul
(
CA
,
CB
)
idx1
=
torch
.
randint
(
0
,
255
,
size
=
(
dim1
,
dim2
),
device
=
'cuda'
).
int
()
C7
-=
(
qb
*
B
.
sum
(
0
).
view
(
1
,
-
1
)
*
zpa
)
+
(
qa
*
A
.
sum
(
1
).
view
(
-
1
,
1
)
*
zpb
)
idx2
=
torch
.
randint
(
0
,
255
,
size
=
(
dim1
,
dim2
),
device
=
'cuda'
).
int
()
C7
-=
zpa
*
zpb
*
A
.
shape
[
1
]
histogram1
=
torch
.
zeros
((
256
,
256
)).
cuda
()
C7
/=
qa
*
qb
histogram2
=
torch
.
zeros
((
256
,
256
)).
cuda
()
F
.
histogram_scatter_add_2d
(
histogram2
,
idx1
,
idx2
,
source
)
print
(
''
)
#print(C0.flatten()[:10])
print
(
C1
.
flatten
()[:
10
])
print
(
C2
.
flatten
()[:
10
])
print
(
C3
.
flatten
()[:
10
])
print
(
C5
.
flatten
()[:
10
])
print
(
C6
.
flatten
()[:
10
])
print
(
C7
.
flatten
()[:
10
])
err1
=
torch
.
abs
(
C1
-
C2
).
mean
().
item
()
err2
=
torch
.
abs
(
C1
-
C3
).
mean
().
item
()
err3
=
torch
.
abs
(
C1
-
C4
).
mean
().
item
()
err4
=
torch
.
abs
(
C1
-
C5
).
mean
().
item
()
err5
=
torch
.
abs
(
C1
-
C6
).
mean
().
item
()
err6
=
torch
.
abs
(
C1
-
C7
).
mean
().
item
()
print
(
err1
,
err2
,
err3
,
err4
,
err5
,
err6
)
for
i
in
range
(
dim1
):
for
j
in
range
(
dim2
):
histogram1
[
idx1
[
i
,
j
].
item
(),
idx2
[
i
,
j
].
item
()]
+=
source
[
i
,
j
]
torch
.
testing
.
assert_allclose
(
histogram1
,
histogram2
)
torch
.
testing
.
assert_allclose
(
histogram1
.
sum
(),
source
.
sum
())
tests/test_modules.py
View file @
c771b3a7
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from itertools import product

from torch import nn

import bitsandbytes as bnb


class MockArgs(object):
    def __init__(self, initial_data):
        for key in initial_data:
            setattr(self, key, initial_data[key])


class MLP8bit(torch.nn.Module):
    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
        super(MLP8bit, self).__init__()
        self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold)
        self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def get_args():
    args = MockArgs([])
    args.quant_type = 'vector'
    args.use_8bit_training = 'full'
    args.clip_freq = 9999
    return args


def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
    idx = torch.isclose(a, b, rtol, atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f'Too many values not close: assert {sumval} < {count}')
        torch.testing.assert_allclose(a, b, rtol, atol)
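
# Example usage of the helper above (illustrative values only): up to `count` elements
# may disagree before the strict comparison is run, which is how the 8-bit module tests
# below tolerate a handful of rounding mismatches.
#
#   a = torch.randn(64, 64)
#   b = a + 1e-9 * torch.randn(64, 64)
#   assert_all_approx_close(a, b, atol=1e-6, rtol=1e-5, count=10)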
class
LinearFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
def
get_8bit_linear_trimmed
(
x
,
stochastic
=
False
,
trim_value
=
3.0
):
round_func
=
LinearFunction
.
round_stoachastic
if
stochastic
else
torch
.
round
norm
=
math
.
sqrt
(
math
.
pi
)
/
math
.
sqrt
(
2.0
)
#std = torch.abs(x).mean()*norm
std
=
torch
.
std
(
x
)
max1
=
std
*
trim_value
x
=
x
/
max1
*
127
x
=
round_func
(
x
)
x
[
x
>
127
]
=
127
x
[
x
<
-
127
]
=
-
127
x
=
x
/
127
*
max1
return
x
def
quant
(
x
,
quant_type
,
dim
=
1
):
if
quant_type
==
'linear'
:
max1
=
torch
.
abs
(
x
).
max
().
float
()
xq
=
torch
.
round
(
x
/
max1
*
127
).
to
(
torch
.
int8
)
return
xq
,
max1
elif
quant_type
==
'vector'
:
max1
=
torch
.
amax
(
torch
.
abs
(
x
),
dim
=
dim
,
keepdim
=
True
)
xq
=
torch
.
round
(
x
/
max1
*
127
).
to
(
torch
.
int8
)
return
xq
,
max1
elif
quant_type
==
'min-max'
:
maxA
=
torch
.
amax
(
x
,
dim
=
dim
,
keepdim
=
True
).
float
()
minA
=
torch
.
amin
(
x
,
dim
=
dim
,
keepdim
=
True
).
float
()
scale
=
(
maxA
-
minA
)
/
2.0
xq
=
torch
.
round
(
127
*
(
x
-
minA
-
scale
)
/
scale
).
to
(
torch
.
int8
)
return
xq
,
(
minA
.
float
(),
scale
.
float
())
else
:
return
None
    def dequant(xq, S1, S2, dtype, quant_type):
        if quant_type == 'linear':
            norm = S1*S2/(127*127)
            # double cast needed to prevent overflows
            return (xq.float()*norm).to(dtype)
        elif quant_type == 'vector':
            x = xq.float()
            if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0)
            #print(x.shape, S1.shape, S2.shape)
            if len(S1.shape) == 2:
                x *= S1.t()/127
            else:
                x *= S1/127
            x *= S2/127
            return x.to(dtype)
        else:
            return None
    def dequant_min_max(xq, A, B, SA, SB, dtype):
        offset = B.float().t().sum(0)*(SA[0]+SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t()/127
        else:
            x *= SB/127
        x *= SA[1]/127
        x += offset
        return x.to(dtype)
    def get_8bit_linear(x, stochastic=False):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.abs(x).max()
        x = x/max1*127
        x = round_func(x)/127*max1
        #x = torch.round(x)/128*max1
        return x
    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1 == 0] = 1.0
        x = (x*127)/max1
        x = round_func(x)/127*max1
        return x
    @staticmethod
    def round_stoachastic(x):
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx-torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype))
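    # Note: rounding up with probability equal to the fractional part makes the
    # rounding above unbiased, e.g. E[round_stoachastic(2.3)] == 2.3, while
    # torch.round(2.3) is always 2.0; this keeps repeated quantize/dequantize
    # cycles from drifting in one direction.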
    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
        code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out
    @staticmethod
    def fake_8bit_storage_quantile(w, args):
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        #C = bnb.functional.quantize_no_absmax(code, w)
        #out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
        #print(out)
        #out = out.half()
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out
    @staticmethod
    def fake_8bit_storage_stoachstic(w):
        rand = torch.rand(1024, device=w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
        out = bnb.functional.dequantize_blockwise(absmax, C)
        out = out.half()
        w.copy_(out)
        return out
    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
        blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]

        mask = torch.zeros_like(blocked_w)
        mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
        mask = mask.bool()

        # 1. zero out max values
        # 2. quantize + dequantize
        # 3. write back max values
        # 4. copy matrix back to weight
        values = blocked_w[mask]
        blocked_w[mask] = 0

        code = bnb.functional.create_dynamic_map()
        code = code.to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
        bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)

        blocked_w[mask] = values

        unblocked_w = blocked_w.flatten().view(w.shape)

        w.copy_(unblocked_w)
        return unblocked_w
    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
        if args.use_8bit_training != 'off':
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
            #if torch.rand(1) < 0.01:
                #output32 = torch.matmul(x, weight.t())
                #err = torch.abs(output-output32).float()
                #relerr = err/(torch.abs(output32).float()+1e-8)
                #print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
        else:
            #output = torch.matmul(x, weight.t())
            output = torch.einsum('bsi,oi->bso', x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output
    @staticmethod
    def backward(ctx, grad_output):
        x, weight, bias = ctx.saved_tensors
        args = ctx.args
        stochastic = False
        grad_input = grad_weight = grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0)

        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == 'forward+wgrad':
            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)

            #grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == 'full':
            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)

            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum('bsi,bso->oi', x, grad_output)

        return grad_input, grad_weight, grad_bias, None


@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding'])
def test_embeddings(embcls):
    bnb.optim.GlobalOptimManager.get_instance().initialize()
    emb1 = torch.nn.Embedding(100, 512).cuda()
    emb2 = embcls(100, 512).cuda()

    adam1 = bnb.optim.Adam8bit(emb1.parameters())
    adam2 = bnb.optim.Adam8bit(emb2.parameters())

    batches = torch.randint(1, 100, size=(100, 4, 32)).cuda()

    for i in range(100):
        batch = batches[i]

        embedded1 = emb1(batch)
        embedded2 = emb2(batch)

        l1 = embedded1.mean()
        l2 = embedded2.mean()

        l1.backward()
        l2.backward()

        adam1.step()
        adam2.step()

        adam1.zero_grad()
        adam2.zero_grad()

        assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8
        assert adam2.state[emb2.weight]['state1'].dtype == torch.float32


class Linear8bit(nn.Module):
    def __init__(self, input_features, output_features, bias=True, args=None):
        super(Linear8bit, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.args = args

        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter('bias', None)

        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        self.args.training = self.training

        return LinearFunction.apply(x, self.weight, self.bias, self.args)
def test_linear8bit():
    l0 = torch.nn.Linear(32, 64).cuda().half()
    l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half()
    l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
    l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half()

    l0.weight.data = l2.weight.data.clone()
    l0.bias.data = l2.bias.data.clone()

    l1.weight.data = l2.weight.data.clone()
    l1.bias.data = l2.bias.data.clone()

    l3.weight.data = l2.weight.data.clone()
    l3.bias.data = l2.bias.data.clone()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        t = torch.randn(16, 8, 64, device='cuda').half()
        b2 = b1.clone()
        b3 = b1.clone()
        b0 = b1.clone()

        o0 = l0(b0)
        o1 = l1(b1)
        o2 = l2(b2)
        o3 = l3(b3)

        assert_all_approx_close(o1, o2, atol=0.013, rtol=0.05, count=1)
        assert_all_approx_close(o3, o2, atol=0.013, rtol=0.05, count=1)

        loss0 = torch.nn.functional.mse_loss(o0, t)
        loss1 = torch.nn.functional.mse_loss(o1, t)
        loss2 = torch.nn.functional.mse_loss(o2, t)
        loss3 = torch.nn.functional.mse_loss(o3, t)

        loss0.backward()
        loss1.backward()
        loss2.backward()
        loss3.backward()

        assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
        assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
        assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
        assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)

        err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item()
        err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item()
        err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item()

        assert err1*0.8 < err2
        assert err2*0.8 < err3
        assert err3*0.8 < err1

        l0.weight.grad = None
        l1.weight.grad = None
        l2.weight.grad = None
        l3.weight.grad = None
        l0.bias.grad = None
        l1.bias.grad = None
        l2.bias.grad = None
        l3.bias.grad = None
threshold = [0.0, 3.0]
values = threshold
names = ['threshold_{0}'.format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert l1.weight.device.type == 'cuda'
    assert l1.weight.dtype == torch.float16

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = l1(b1)
        if i == 1:
            assert l1.state.CxB is not None
def test_linear8bitlt_accumulated_gradient():
    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)

    acc_steps = 10

    for i in range(10):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
        loss2 = o2.mean()
        loss1.backward()
        loss2.backward()
        if i == 2:
            assert l1[0].state.CxB is not None
            assert l1[1].state.CxB is not None

        if i > 0 and i % acc_steps == 0:
            opt1.step()
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
            assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
        else:
            torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
            torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
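The schedule being verified above is the usual gradient-accumulation recipe; a minimal CPU-only sketch of the same pattern with a plain torch layer (acc_steps and the shapes are illustrative):

import torch

model = torch.nn.Linear(8, 8)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
acc_steps = 4
for step in range(16):
    loss = model(torch.randn(2, 8)).mean()
    loss.backward()                          # grads keep accumulating across steps
    if step > 0 and step % acc_steps == 0:
        opt.step()
        opt.zero_grad(set_to_none=True)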
threshold = [0.0, 2.0]
values = threshold
names = ['threshold_{0}'.format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold):
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = l1(b1)
        assert o1.dtype == torch.float16

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda')

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == 'cuda'
    assert mlp.fc2.weight.device.type == 'cuda'

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda')

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == 'cuda'
    assert mlp.fc2.weight.device.type == 'cuda'
tests/test_optim.py View file @ c771b3a7
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
import shutil
import uuid
import pytest
import ctypes
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
...
@@ -14,7 +11,9 @@ import bitsandbytes.functional as F
from os.path import join
from itertools import product

import apex
#import apex

k = 20

def get_temp_dir():
    path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4()))
...
@@ -26,55 +25,47 @@ def rm_path(path):
str2optimizers = {}
str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam)
str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
# str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
# str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers['adamw'] = (torch.optim.AdamW, bnb.optim.AdamW)
#str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
# str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['adagrad'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad(pxx, 0.01, block_wise=False))
str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
# str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))
str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers['adamw8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.AdamW8bit(pxx, block_wise=True))
str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['adagrad8bit_blockwise'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad8bit(pxx, 0.01, block_wise=True))
str2statenames = {}
str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['adamw'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['momentum'] = [('momentum_buffer', 'state1')]
str2statenames['lars'] = [('momentum_buffer', 'state1')]
str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['rmsprop'] = [('square_avg', 'state1')]
str2statenames['adagrad'] = [('sum', 'state1')]
str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
str2statenames['adamw8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')]
str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')]
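The two tables above drive the parametrized tests that follow: for an entry like 'adam', the first element builds the torch reference optimizer and the second the bitsandbytes one, while str2statenames pairs each reference state key with its bnb counterpart. A small sketch of how a test consumes them (construction and key lookup only; stepping the bnb optimizer needs CUDA tensors):

torch_optim_cls, bnb_optim_cls = str2optimizers['adam']
p_ref = torch.nn.Parameter(torch.randn(16, 16))
p_bnb = torch.nn.Parameter(p_ref.clone().detach())
opt_ref = torch_optim_cls([p_ref])
opt_bnb = bnb_optim_cls([p_bnb])
for torch_key, bnb_key in str2statenames['adam']:
    print(torch_key, '<->', bnb_key)   # exp_avg <-> state1, exp_avg_sq <-> state2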
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ['adam', 'adamw', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad']
optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
...
@@ -89,12 +80,12 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 2e-6, 1e-5
        atol, rtol = 1e-6, 1e-5
    else:
        atol, rtol = 1e-4, 1e-3

    for i in range(50):
    for i in range(k):
        g = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()
...
@@ -107,7 +98,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
        torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)

        if i % 10 == 0 and i > 0:
        if i % (k//5) == 0 and i > 0:
            path = get_temp_dir()
            torch.save(bnb_optimizer.state_dict(), join(path, 'opt.pt'))
            del bnb_optimizer
...
@@ -148,7 +139,6 @@ def test_global_config(dim1, dim2, gtype):
    eps = 1e-8

    bnb.optim.GlobalOptimManager.get_instance().initialize()
    bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True)
    bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)
    bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
...
@@ -163,8 +153,6 @@ def test_global_config(dim1, dim2, gtype):
    else:
        atol, rtol = 1e-4, 1e-3

    original_p2 = p2[mask].clone()
    for i in range(50):
        g1 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1 + 0.001
        g2 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1 + 0.001
...
@@ -173,38 +161,17 @@ def test_global_config(dim1, dim2, gtype):
        p2.grad = g2
        p3.grad = g3

        if i > 30 and i % 10 == 0:
            g1.data[mask] = 0.0
            g2.data[mask] = 0.0
            p1.grad = g1
            p2.grad = g2
            original_p1 = p1[mask].clone()
            original_p2 = p2[mask].clone()
            og_s1 = adam2.state[p2]['state1'][mask].clone()
            og_s2 = adam2.state[p2]['state2'][mask].clone()
            og_s11 = adam2.state[p1]['state1'][mask].clone()
            og_s21 = adam2.state[p1]['state2'][mask].clone()

        adam2.step()

        assert adam2.state[p3]['state1'].dtype == torch.uint8
        assert adam2.state[p3]['state2'].dtype == torch.uint8

        if i > 30 and i % 10 == 0:
            torch.testing.assert_allclose(original_p2, p2[mask])
            torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1)
            torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2)
            assert ((p1[mask] - original_p1) == 0.0).sum() < p1.numel()
            assert ((adam2.state[p1]['state1'][mask] - og_s11) == 0.0).sum() == 0.0
            assert ((adam2.state[p1]['state2'][mask] - og_s21) == 0.0).sum() == 0.0
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'adamw8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise']
optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
...
@@ -370,13 +337,12 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
    if dim1 == 1 and dim2 == 1: return
    p1 = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.1
    bnb_optimizer = str2optimizers[optim_name][1]([p1])

    g = torch.randn(dim1, dim2, device='cuda', dtype=gtype)*0.01
    p1.grad = g
    for i in range(5000):
    for i in range(k):
        if i == 500:
        if i == k//5:
            # 100 iterations for burn-in
            torch.cuda.synchronize()
            t0 = time.time()
...
@@ -386,23 +352,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
    torch.cuda.synchronize()
    s = time.time()-t0
    print('')
    params = 4500*4096*4096
    params = (k-k//5)*dim1*dim2
    print(optim_name, gtype, s/params)
    #assert s < 3.9
def test_str_betas():
    betas = (0.80, 0.95)
    strbetas = '(0.80, 0.95)'

    layer = torch.nn.Linear(10, 10)

    base = bnb.optim.Adam(layer.parameters(), betas=betas)
    strbase = bnb.optim.Adam(layer.parameters(), betas=strbetas)
    assert base.defaults['betas'][0] == 0.8
    assert base.defaults['betas'][1] == 0.95
    assert strbase.defaults['betas'][0] == 0.8
    assert strbase.defaults['betas'][1] == 0.95