OpenDAS / bitsandbytes

Commit ba51d95d, authored Jul 11, 2023 by Tim Dettmers (parent: b8da4a16)

    Added more extensive gemv tests; blocksize guard for gemv.

Showing 5 changed files with 122 additions and 69 deletions:

- bitsandbytes/autograd/_functions.py (+7, -1)
- bitsandbytes/functional.py (+1, -0)
- csrc/kernels.cu (+7, -4)
- csrc/ops.cu (+1, -0)
- tests/test_functional.py (+106, -64)
bitsandbytes/autograd/_functions.py

```diff
@@ -3,6 +3,7 @@ import warnings
 from dataclasses import dataclass
 from functools import reduce  # Required in Python 3
 from typing import Tuple, Optional, List
+from warnings import warn

 import torch
@@ -565,6 +566,11 @@ def matmul(
 def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
-        return F.gemv_4bit(A, B.t(), out, state=quant_state)
+        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
+        if A.shape[-1] % blocksize != 0:
+            warn(f'Some matrices hidden dimension is not a multiple of {blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
+            return MatMul4Bit.apply(A, B, out, bias, quant_state)
+        else:
+            return F.gemv_4bit(A, B.t(), out, state=quant_state)
     else:
         return MatMul4Bit.apply(A, B, out, bias, quant_state)
```
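The guard added here routes around the fused inference kernel whenever the hidden dimension is not a multiple of the quantization blocksize. As a minimal sketch of that dispatch (illustrative names, not the library API; `blocksize` sits at index 3 of `quant_state`, as unpacked above):

```python
import warnings

def dispatch_4bit_matmul(A_shape, requires_grad, blocksize):
    """Illustrative re-statement of the guard above, not the library API.

    The fused gemv kernel only handles a single-row inference input whose
    hidden dimension is a multiple of the quantization blocksize; anything
    else falls back to the general (slower) MatMul4Bit path.
    """
    numel = 1
    for s in A_shape:
        numel *= s
    single_row = numel == A_shape[-1]  # e.g. shape (1, hidden)
    if single_row and not requires_grad:
        if A_shape[-1] % blocksize != 0:
            warnings.warn(f'hidden dim {A_shape[-1]} is not a multiple of {blocksize}; using the slow path')
            return 'MatMul4Bit'
        return 'gemv_4bit'
    return 'MatMul4Bit'

assert dispatch_4bit_matmul((1, 4096), False, 64) == 'gemv_4bit'
assert dispatch_4bit_matmul((1, 4100), False, 64) == 'MatMul4Bit'  # the new guard
assert dispatch_4bit_matmul((8, 4096), False, 64) == 'MatMul4Bit'  # batched / training
```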
bitsandbytes/functional.py

```diff
@@ -1504,6 +1504,7 @@ def gemv_4bit(
             lib.cgemm_4bit_inference_naive_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
+
     else:
         raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
```
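For orientation: the `ct.c_int32(state[3])` argument is the quantization blocksize (index 3 of the state tuple unpacked in `_functions.py` above), and `state[-1]` is the 16-entry data-type table. A hedged sketch of the dtype dispatch this hunk sits inside (only the fp32 symbol appears above; the fp16/bf16 names are assumed siblings):

```python
import torch

# Sketch of gemv_4bit's dtype dispatch, not the library's actual control flow.
def pick_kernel(dtype):
    return {
        torch.float16: 'cgemm_4bit_inference_naive_fp16',   # assumed by analogy
        torch.bfloat16: 'cgemm_4bit_inference_naive_bf16',  # assumed by analogy
        torch.float32: 'cgemm_4bit_inference_naive_fp32',   # shown in the hunk above
    }.get(dtype)

assert pick_kernel(torch.float32) == 'cgemm_4bit_inference_naive_fp32'
assert pick_kernel(torch.int8) is None  # the real code raises NotImplementedError
```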
csrc/kernels.cu

```diff
@@ -222,6 +222,7 @@ __device__ half dhDequantizeNF4(unsigned char val)
 __device__ float dDequantizeNF4(unsigned char val)
 {
+
   // the values for this tree was generated by test_normal_map_tree
   // in the file tests/test_functional.py
   if((val & 0b1000) == 8)
```
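`dDequantizeNF4` walks a binary comparison tree over the 4-bit code; the `val & 0b1000` test splits negative-or-zero codes (0..7) from positive ones (8..15). A table-lookup equivalent in Python, using the published NF4 quantile values (a sketch; the kernel hard-codes the same mapping as nested branches):

```python
import torch

# The 16 NF4 code values as published for bitsandbytes/QLoRA.
NF4_VALUES = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0])

def dequantize_nf4(codes: torch.Tensor) -> torch.Tensor:
    """Map 4-bit NF4 codes (0..15) to floats by table lookup."""
    return NF4_VALUES[codes.long()]

print(dequantize_nf4(torch.tensor([0, 7, 8, 15])))  # [-1.0, 0.0, 0.0796, 1.0]
```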
```diff
@@ -3526,10 +3527,9 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive
 {
   // per threadblock:
-  // load step-by-step in chunks of [64,warps]: 1x64 * [64,warps] -> [1,warps]
+  // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps]
-  // 64 packed 8bit values per warp AND each warp loads one output for B -> 1x128 * 128x1
   // 4 warps -> 4 loads per iter
-  // 1x128 * 128x4 -> 1x4 outputs
+  // 1x32 * 32x4 -> 1x4 outputs per thread block
   typedef cub::WarpReduce<float> WarpReduce;
   __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/32];
@@ -3547,7 +3547,6 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive
   for(int i = threadIdx.x; i < 16; i++)
     quant_map[i] = T(datatype[i]);
-
   __syncthreads();

   // A: [1, K]
@@ -3563,6 +3562,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive
     {
       if((inner_idx_halved + num_values_8bit) < (K/2))
       {
+        // this is the most important for performance considerations
         reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B + (inner_idx_halved))/(num_values_8bit)];
       }
       else
@@ -3597,6 +3597,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive
       if(inner_idx + num_values_4bit < K)
       {
+        // this is also relatively important for performance
         if(BITS==16)
         {
           reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + 0];
@@ -3618,6 +3619,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive
       }
       else
+        #pragma unroll
         for(int k = 0; k < num_values_4bit; k++)
           if(inner_idx + k < K)
             local_A[k] = A[inner_idx + k];
@@ -3625,6 +3627,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive
           local_A[k] = T(0.0f);

+    // accumulate in float; small performance hit for Ampere, but lower error for outputs
     #pragma unroll
     for(int k = 0; k < num_values_4bit; k++)
     {
```
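The updated comments describe the kernel's shape: each threadblock computes a 1x4 slice of the output, its 4 warps each reducing one dot product over K in 32-element chunks, dequantizing B's 4-bit codes on the fly and accumulating in float. A slow Python reference for that computation (a sketch assuming absmax blocks run contiguously along B's rows; not the CUDA kernel):

```python
import torch

def gemv_4bit_reference(A, B_codes, absmax, quant_map, blocksize):
    """Reference for the 4-bit inference gemv: out[1, N] = A[1, K] @ deq(B).T.

    B_codes holds integer codes 0..15 for an [N, K] weight; each element is
    dequantized as quant_map[code] * absmax[its block]. Accumulation happens
    in float32, mirroring the kernel comment above.
    """
    N, K = B_codes.shape
    deq = quant_map[B_codes.long()].float()
    scales = absmax.repeat_interleave(blocksize).reshape(N, K).float()
    return (A.float() @ (deq * scales).t()).to(A.dtype)

# tiny usage example with made-up data
torch.manual_seed(0)
K, N, bs = 64, 4, 32
A = torch.randn(1, K)
codes = torch.randint(0, 16, (N, K))
absmax = torch.rand(N * K // bs) + 0.5  # one scale per 32-element block
qmap = torch.linspace(-1, 1, 16)        # stand-in for the NF4/FP4 table
out = gemv_4bit_reference(A, codes, absmax, qmap, bs)
assert out.shape == (1, N)
```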
csrc/ops.cu

```diff
@@ -735,6 +735,7 @@ template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int
   int num_blocks = (m+3)/4;

   kgemm_4bit_inference_naive<T, 128, BITS><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
+  CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }

 template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
```
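Two details worth noting here: the launch is now followed by `CUDA_CHECK_RETURN(cudaPeekAtLastError())`, so a bad launch configuration surfaces immediately rather than corrupting a later call, and the grid size `(m+3)/4` is ceiling division, matching the kernel comment that each 128-thread block (4 warps) produces 4 outputs. A quick check of that arithmetic:

```python
def num_blocks(m, outputs_per_block=4):
    """Ceiling division used for the launch grid: one block per 4 outputs."""
    return (m + outputs_per_block - 1) // outputs_per_block

assert num_blocks(1) == 1 and num_blocks(4) == 1 and num_blocks(5) == 2
assert all((m + 3) // 4 == num_blocks(m) for m in range(1, 1000))
```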
tests/test_functional.py

```diff
@@ -2262,7 +2262,7 @@ def test_fp4_quant(dtype):
     A2 = F.dequantize_fp4(qa, SA)
     err = (A1 - A2).abs().float()
-    relerr = (err/A1.abs().float()).mean()
+    relerr = (err/(A1.abs().float()+1e-8)).mean()
     idx = err > 1.0
     err = err.mean()
```
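The added `1e-8` keeps the relative error finite when an element of `A1` is exactly zero, which random test inputs can produce. A minimal demonstration of the failure mode the epsilon avoids:

```python
import torch

A1 = torch.tensor([0.0, 1.0, -2.0])
err = torch.tensor([0.1, 0.1, 0.1])
relerr_old = (err / A1.abs().float()).mean()           # division by zero -> inf
relerr_new = (err / (A1.abs().float() + 1e-8)).mean()  # finite
assert torch.isinf(relerr_old) and torch.isfinite(relerr_new)
```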
```diff
@@ -2361,91 +2361,133 @@ def test_normal_map_tree():
 @pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
 @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
+@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-def test_gemv_4bit(dtype, storage_type, double_quant):
-    print('')
-    for dim in [128, 256, 512, 1024, 2048, 4096, 6144]:
+def test_gemv_4bit(dtype, storage_type, double_quant, kind):
+    for dim in [128, 256, 512, 1024, 2048, 4096]:
         #for dim in [4*1024]:
-        #for dim in [1*16]:
-        errs = []
-        relerrs = []
-        max_err = 0
-        max_relerr = 0
+        #for dim in [1*128]:
+        errs1 = []
+        errs2 = []
+        errs3 = []
+        relerrs1 = []
+        relerrs2 = []
+        relerrs3 = []
+        max_errs1 = []
+        max_errs2 = []
+        max_errs3 = []

         for i in range(100):
-            #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
-            #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
-            #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
-            #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
-            A = torch.randn(1, dim, dtype=dtype, device='cuda')
-            #B = torch.randn(4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-            B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-            #print('')
-            #print(A)
-            #print(B.t())
-            #A[:, :-1] = 0
-            #B[:, :-1] = 0
-            #A.flatten()[:-1] = 0
-            #B.flatten()[:-1] = 0
+            if kind == 'fc1':
+                A = torch.randn(1, dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
+            elif kind == 'fc2':
+                A = torch.randn(1, 4*dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim, 4*dim, dtype=dtype, device='cuda')/math.sqrt(dim)
+            elif kind == 'attn':
+                A = torch.randn(1, dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
+            elif kind == 'attn_packed':
+                A = torch.randn(1, dim, dtype=dtype, device='cuda')
+                B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

             qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
-            #F.dequantize_4bit(qB, state)
             C3 = torch.matmul(A, B.t())
             C2 = F.gemv_4bit(A, qB.t(), state=state)
             A.requires_grad = True
             C1 = bnb.matmul_4bit(A, qB.t(), state)

-            #print(state)
-            #print(qB)
-            #print('')
-            #print(A)
-            #print(B)
-            #print('='*89)
-            #print(C3)
-            #print(C1.shape, C2.shape)
-            # tensor cores are non-deterministic
-            # so we need to analyze errors around the mean
-            # to test our implementation
-            err = torch.abs(C1-C2).float()
-            mag = torch.abs(C1).float()+1e-5
-            relerr = err/mag
-            max_err = max(err.max(), max_err)
-            max_relerr = max(relerr.max(), max_relerr)
-            err = err.mean().item()
-            relerr = relerr.mean().item()
-            #print(err)
-            errs.append(err)
-            relerrs.append(relerr)
+            err1 = (C1-C2).abs().float()
+            err2 = (C3-C2).abs().float()
+            err3 = (C3-C1).abs().float()
+
+            mag1 = torch.abs(C1).float()+1e-5
+            mag2 = torch.abs(C3).float()+1e-5
+            mag3 = torch.abs(C3).float()+1e-5
+
+            relerr1 = err1/mag1
+            relerr2 = err2/mag2
+            relerr3 = err3/mag3
+
+            max_err1 = err1.max()
+            max_err2 = err2.max()
+            max_err3 = err3.max()
+
+            errs1.append(err1.mean().item())
+            errs2.append(err2.mean().item())
+            errs3.append(err3.mean().item())
+
+            relerrs1.append(relerr1.mean().item())
+            relerrs2.append(relerr2.mean().item())
+            relerrs3.append(relerr3.mean().item())
+
+            max_errs1.append(max_err1.item())
+            max_errs2.append(max_err2.item())
+            max_errs3.append(max_err3.item())

             c = int(C1.numel()*0.0014*(dim/256))+1
             c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
-        #print('')
-        #print(dim, sum(errs)/len(errs)/math.sqrt(dim))
-        #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
-        #print(dim, (max_err.item(), max_relerr.item()))
-        print(C1.flatten()[-20:])
-        print(C2.flatten()[-20:])
-        #print(C1.flatten())
-        #print(C2.flatten())
-        #print(C3.flatten()[-20:])
-        print(sum(errs)/len(errs)/math.sqrt(dim), dim)
-        print(sum(relerrs)/len(relerrs)/math.sqrt(dim), dim)
+        err1 = sum(errs1)/len(errs1)/math.sqrt(dim)
+        err2 = sum(errs2)/len(errs2)/math.sqrt(dim)
+        err3 = sum(errs3)/len(errs3)/math.sqrt(dim)
+        relerr1 = sum(relerrs1)/len(relerrs1)/math.sqrt(dim)
+        relerr2 = sum(relerrs2)/len(relerrs2)/math.sqrt(dim)
+        relerr3 = sum(relerrs3)/len(relerrs3)/math.sqrt(dim)
+        maxerr1 = sum(max_errs1)/len(max_errs1)/math.sqrt(dim)
+        maxerr2 = sum(max_errs2)/len(max_errs2)/math.sqrt(dim)
+        maxerr3 = sum(max_errs3)/len(max_errs3)/math.sqrt(dim)
+        absratio = err2/err3
+        relratio = relerr2/relerr3
+        maxratio = relerr2/relerr3
+
+        # for debugging if the tests fails
+        #
+        #print('='*80)
+        #print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
+        #print(C1.flatten()[-20:])
+        #print(C2.flatten()[-20:])
+        #print(f'inference vs training abs: {err1}')
+        #print(f'inference vs training rel: {relerr1}')
+        #print(f'inference vs training max: {maxerr1}')
+        #print(f'inference vs training vs torch err ratio abs: {absratio}')
+        #print(f'inference vs training vs torch err ratio rel: {relratio}')
+        #print(f'inference vs training vs torch err ratio max: {maxratio}')
         if dtype == torch.float16:
-            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
-            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
+            if dim <= 512:
+                assert err1 < 7e-5
+                assert relerr1 < 0.0008
+            else:
+                assert err1 < 6e-5
+                assert relerr1 < 2e-4
+            assert absratio < 1.005 and absratio > 0.995
+            assert relratio < 1.005 and relratio > 0.995
+            assert maxratio < 1.005 and maxratio > 0.995
         elif dtype == torch.float32:
-            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-8
-            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 1e-7
+            if dim <= 512:
+                assert err1 < 5e-8
+                assert relerr1 < 1e-6
+                assert maxerr1 < 1e-7
+            else:
+                assert err1 < 5e-8
+                assert relerr1 < 8e-6
+                assert maxerr1 < 1e-7
+            assert absratio < 1.005 and absratio > 0.995
+            assert relratio < 1.005 and relratio > 0.995
+            assert maxratio < 1.005 and maxratio > 0.995
         elif dtype == torch.bfloat16:
-            assert sum(errs)/len(errs)/math.sqrt(dim) < 3e-4
-            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.003
+            if dim <= 512:
+                assert err1 < 5e-4
+                assert relerr1 < 0.007
+                assert maxerr1 < 0.015
+            else:
+                assert err1 < 2e-4
+                assert relerr1 < 0.002
+                assert maxerr1 < 0.0012
+            assert absratio < 1.005 and absratio > 0.995
+            assert relratio < 1.04 and relratio > 0.96
+            assert maxratio < 1.02 and maxratio > 0.98

 @pytest.mark.skip("Row scale has some bugs for ampere")
 def test_managed():
```
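Beyond per-pair thresholds, the rewritten test checks ratios: `err2` (torch reference vs the inference gemv) against `err3` (torch reference vs the training matmul path), asserting the ratio stays in a narrow band around 1.0, since tensor-core results are non-deterministic and absolute thresholds alone are brittle. A sketch of that methodology on synthetic data:

```python
import torch

def error_ratio(C_ref, C_train, C_infer, eps=1e-9):
    """If the inference kernel is as accurate as the training path relative
    to a common reference, the mean-abs-error ratio should be close to 1."""
    err_infer = (C_ref - C_infer).abs().float().mean()
    err_train = (C_ref - C_train).abs().float().mean()
    return (err_infer / (err_train + eps)).item()

# two paths with the same noise level give a ratio near 1.0, which is what
# the `0.995 < absratio < 1.005` style assertions above encode
torch.manual_seed(0)
C_ref = torch.randn(1, 4096)
ratio = error_ratio(C_ref,
                    C_ref + 1e-3 * torch.randn_like(C_ref),
                    C_ref + 1e-3 * torch.randn_like(C_ref))
assert 0.9 < ratio < 1.1
```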