OpenDAS / bitsandbytes · Commit 6974920b (unverified)

Enable line-ending and other hygiene lints (#1006)

Authored Feb 01, 2024 by Aarni Koskela; committed via GitHub on Feb 01, 2024.
Parent: 3a630c58

Showing 16 changed files with 50 additions and 71 deletions (+50 −71). As the title suggests, the changes are line-ending and whitespace hygiene fixes; the code content of the hunks below is otherwise unchanged.
Changed files:

  csrc/kernels.cu                          +35 −35
  csrc/pythonInterface.c                    +1  −1
  docs/source/_toctree.yml                  +2  −2
  docs/source/index.mdx                     +3  −3
  docs/source/quickstart.mdx                +2  −2
  environment.yml                           +1  −1
  examples/int8_inference_huggingface.py    +0  −3
  how_to_use_nonpytorch_cuda.md             +1  −1
  install_cuda.py                           +4  −4
  scripts/stale.py                          +1  −1
  tests/test_autograd.py                    +0  −1
  tests/test_cuda_setup_evaluator.py        +0  −8
  tests/test_functional.py                  +0  −2
  tests/test_generation.py                  +0  −3
  tests/test_modules.py                     +0  −3
  tests/test_triton.py                      +0  −1
csrc/kernels.cu

@@ -110,7 +110,7 @@ __device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
        return 1.00000000f*absmax*sign; // 1011
      else
        return 0.66666667f*absmax*sign; // 1010
    else
      if((val & 0b0001) == 1) // 100
        return 5.208333333e-03f*absmax*sign; // 1001
      else
@@ -174,36 +174,36 @@ __device__ half dhDequantizeNF4(unsigned char val)
    if((val & 0b0100) == 4) // 1
      if((val & 0b0010) == 2) // 11
        if((val & 0b0001) == 1) // 111
          return 1.0f;
        else
          return 0.7229568362236023f;
      else
        if((val & 0b0001) == 1) // 110
          return 0.5626170039176941f;
        else
          return 0.44070982933044434f;
    else
      if((val & 0b0010) == 2) //10
        if((val & 0b0001) == 1) // 101
          return 0.33791524171829224f;
        else
          return 0.24611230194568634f;
      else
        if((val & 0b0001) == 1) // 100
          return 0.16093020141124725f;
        else
          return 0.07958029955625534f;
  else
    if((val & 0b0100) == 4) // 0
      if((val & 0b0010) == 2) //01
        if((val & 0b0001) == 1) // 011
          return 0.0f;
        else
          return -0.09105003625154495f;
      else
        if((val & 0b0001) == 1) // 010
          return -0.18477343022823334f;
        else
          return -0.28444138169288635f;
    else
@@ -211,12 +211,12 @@ __device__ half dhDequantizeNF4(unsigned char val)
      if((val & 0b0001) == 1) // 001
        return -0.39491748809814453f;
      else
        return -0.5250730514526367f;
    else
      if((val & 0b0001) == 1) // 000
        return -0.6961928009986877f;
      else
        return -1.0f;
}
@@ -229,36 +229,36 @@ __device__ float dDequantizeNF4(unsigned char val)
    if((val & 0b0100) == 4) // 1
      if((val & 0b0010) == 2) // 11
        if((val & 0b0001) == 1) // 111
          return 1.0f;
        else
          return 0.7229568362236023f;
      else
        if((val & 0b0001) == 1) // 110
          return 0.5626170039176941f;
        else
          return 0.44070982933044434f;
    else
      if((val & 0b0010) == 2) //10
        if((val & 0b0001) == 1) // 101
          return 0.33791524171829224f;
        else
          return 0.24611230194568634f;
      else
        if((val & 0b0001) == 1) // 100
          return 0.16093020141124725f;
        else
          return 0.07958029955625534f;
  else
    if((val & 0b0100) == 4) // 0
      if((val & 0b0010) == 2) //01
        if((val & 0b0001) == 1) // 011
          return 0.0f;
        else
          return -0.09105003625154495f;
      else
        if((val & 0b0001) == 1) // 010
          return -0.18477343022823334f;
        else
          return -0.28444138169288635f;
    else
@@ -266,12 +266,12 @@ __device__ float dDequantizeNF4(unsigned char val)
      if((val & 0b0001) == 1) // 001
        return -0.39491748809814453f;
      else
        return -0.5250730514526367f;
    else
      if((val & 0b0001) == 1) // 000
        return -0.6961928009986877f;
      else
        return -1.0f;
}
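Both trees above (`dhDequantizeNF4` and `dDequantizeNF4`) walk the 4-bit NF4 code one bit at a time; the comments (`// 111`, `// 010`, ...) name the code bits each leaf corresponds to. For orientation, here is a minimal Python sketch (not part of this commit) expressing the same mapping as a 16-entry lookup table built from the constants in the kernel:

```python
# The NF4 code-to-value mapping from the kernel above, as a flat table.
# Index k holds the value the branch tree returns for 4-bit code k.
NF4_TABLE = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]

def dequantize_nf4(val: int) -> float:
    """Python analogue of dDequantizeNF4: map a 4-bit code to its value."""
    return NF4_TABLE[val & 0b1111]
```

In the kernel the branch tree trades a memory lookup for a handful of comparisons; the table form simply makes the monotone code-to-value ordering easy to see.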
@@ -1863,7 +1863,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
        //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps);
        //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val;
        g_val *= gnorm_scale;
        s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
        s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
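The live lines in this hunk are the usual Adam second-moment update plus blockwise dequantization of the 8-bit first-moment state: the 8-bit code `c1s[j]` indexes a quantile table, and the result is rescaled by the per-block `absmax1`. A rough Python sketch of those two lines (all names here are illustrative, not the library's API):

```python
def second_moment_update(s2: float, g: float, beta2: float, gnorm_scale: float) -> float:
    # g_val *= gnorm_scale;  then  s2 = s2*beta2 + (1-beta2)*g*g
    g = g * gnorm_scale
    return s2 * beta2 + (1.0 - beta2) * g * g

def dequantize_state1(quantiles, code: int, absmax1, i: int, block_size: int) -> float:
    # s1_vals[j] = smem_quantiles1[lane_id][c1s[j]] * absmax1[i/BLOCK_SIZE];
    return quantiles[code] * absmax1[i // block_size]
```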
@@ -3069,7 +3069,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
//// use k warps per thread block
//// 1. threadblock use read-only cache to read in register tile for A into shared memory
//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments
//// 3. each warp reads a segment of values 16x32 from B
//// 4. do dequantization from register of B into second pair of registers
//// 5. store (4) into fragment
//// 6. matmul aggregate into fragment C
@@ -3531,7 +3531,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
{
  // per threadblock:
  // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
  // 4 warps -> 4 loads per iter
  // 1x32 * 32x4 -> 1x4 outputs per thread block
@@ -3764,7 +3764,7 @@ template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long
{
  switch(FUNC)
  {
    case FILL:
      A[i] = (T)value;
      break;
    case ARANGE:
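`kfunc` dispatches on the compile-time `FUNC` tag; `FILL` writes `value` into each element, and the `ARANGE` branch is cut off by the hunk boundary. A trivial Python analogue of the `FILL` case (not part of this commit):

```python
def kfunc_fill(A: list, value) -> None:
    # Python analogue of the FILL case above: A[i] = (T)value for every i.
    for i in range(len(A)):
        A[i] = value
```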
csrc/pythonInterface.c

@@ -389,7 +389,7 @@ extern "C"
  int hasPrefetch = 0;
  CUDA_CHECK_RETURN(cudaDeviceGetAttribute(&hasPrefetch, cudaDevAttrConcurrentManagedAccess, device)); // 40ns overhead
  if (hasPrefetch == 0) return;

  CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0));
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
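The `cudaDevAttrConcurrentManagedAccess` query is the documented precondition for managed-memory prefetching: on devices that report 0 (for example pre-Pascal GPUs, or Windows under WDDM), `cudaMemPrefetchAsync` is unsupported and would return an error, so the early return simply skips the hint.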
docs/source/_toctree.yml

- sections:
  - local: index
    title: Bits & Bytes
  - local: quickstart
    title: Quickstart
  - local: installation
    title: Installation
  title: Get started
\ No newline at end of file
docs/source/index.mdx

@@ -149,10 +149,10 @@ To compile from source, you need an installation of CUDA. If `nvcc` is not insta
wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/install_cuda.sh
# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH
# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122}
# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True
# For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc
bash install_cuda.sh 117 ~/local 1
```
To use a specific CUDA version just for a single compile run, you can set the variable `CUDA_HOME`, for example the following command compiles `libbitsandbytes_cuda117.so` using compiler flags for cuda11x with the cuda version at `~/local/cuda-11.7`:
@@ -188,4 +188,4 @@ For 8-bit optimizers or quantization routines, please consider citing the follow
  journal={9th International Conference on Learning Representations, ICLR},
  year={2022}
}
```
\ No newline at end of file
docs/source/quickstart.mdx

@@ -6,7 +6,7 @@
## Minimal example

The following code illustrates the steps above.

```python
```
\ No newline at end of file
environment.yml

@@ -42,4 +42,4 @@ dependencies:
## ENV UPDATE:
# # add new packages to environment.yml, then:
# mamba env update -n bnb -f environment.yml
\ No newline at end of file
examples/int8_inference_huggingface.py

@@ -22,6 +22,3 @@ model = AutoModelForCausalLM.from_pretrained(
)

generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
how_to_use_nonpytorch_cuda.md

@@ -18,7 +18,7 @@ You can also install CUDA version that you need locally with a script provided b
wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/install_cuda.sh
# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH
# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122}
# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True
# For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc
install_cuda.py

@@ -49,13 +49,13 @@ def install_cuda(version, base_path, download_path):
    # Install CUDA
    print(f"Installing CUDA version {version}...")
    install_command = [
        "bash", filepath,
        "--no-drm", "--no-man-page", "--override",
        "--toolkitpath=" + install_path,
        "--toolkit", "--silent"
    ]

    print(f"Running command: {' '.join(install_command)}")

    try:
        subprocess.run(install_command, check=True)
    except subprocess.CalledProcessError as e:
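For context, these appear to be the standard flags of NVIDIA's `.run` installer: `--silent`, `--toolkit`, and `--toolkitpath=` perform an unattended, toolkit-only install into the given prefix, while `--no-drm`, `--no-man-page`, and `--override` skip the kernel-module and man-page components and override the installer's compiler checks.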
@@ -99,4 +99,4 @@ def main():
        sys.exit(1)

if __name__ == "__main__":
    main()
\ No newline at end of file
scripts/stale.py

@@ -55,4 +55,4 @@ def main():
if __name__ == "__main__":
    main()
\ No newline at end of file
tests/test_autograd.py

@@ -519,4 +519,3 @@ def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
    torch.testing.assert_close(
        gradB1, gradB2, atol=0.18, rtol=0.3
    )
tests/test_cuda_setup_evaluator.py

@@ -19,11 +19,3 @@ def test_manual_override(requires_cuda):
    import bitsandbytes as bnb
    loaded_lib = bnb.cuda_setup.main.CUDASetup.get_instance().binary_name
    #assert loaded_lib == 'libbitsandbytes_cuda122.so'
tests/test_functional.py

@@ -2345,5 +2345,3 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant):
    torch.testing.assert_close(A, C2)
    #torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
    #torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
tests/test_generation.py

@@ -120,6 +120,3 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
        for out in outputs:
            print(out)
        raise ValueError(f'Failure count: {failure_count}/{n_cases}')
tests/test_modules.py

@@ -637,6 +637,3 @@ def test_4bit_warnings():
        net(inp)
    assert len(record) == 2
tests/test_triton.py

@@ -58,4 +58,3 @@ def test_switchback(vector_wise_quantization):
    print('GX1', err_sb, err_baseline)
    assert err_sb < 2 * err_baseline