Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
1ed2fa2f
"csrc/pythonInterface.cpp" did not exist on "8258b4364a21a4da2572cb644d0926080c3268da"
Commit
1ed2fa2f
authored
Aug 16, 2022
by
Tim Dettmers
Browse files
Removed storage() from get_ptr; added boilerplate for bias dequant_mm.
parent
26efb154
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
27 additions
and
21 deletions
+27
-21
bitsandbytes/functional.py
bitsandbytes/functional.py
+8
-3
csrc/kernels.cu
csrc/kernels.cu
+2
-2
csrc/kernels.cuh
csrc/kernels.cuh
+1
-1
csrc/ops.cu
csrc/ops.cu
+2
-3
csrc/ops.cuh
csrc/ops.cuh
+1
-1
csrc/pythonInterface.c
csrc/pythonInterface.c
+2
-2
tests/test_functional.py
tests/test_functional.py
+11
-9
No files found.
bitsandbytes/functional.py
View file @
1ed2fa2f
...
@@ -218,7 +218,7 @@ def get_ptr(A: Tensor) -> ct.c_void_p:
...
@@ -218,7 +218,7 @@ def get_ptr(A: Tensor) -> ct.c_void_p:
if
A
is
None
:
if
A
is
None
:
return
None
return
None
else
:
else
:
return
ct
.
c_void_p
(
A
.
data
.
storage
().
data_ptr
())
return
ct
.
c_void_p
(
A
.
data
.
data_ptr
())
def
pre_call
(
device
):
def
pre_call
(
device
):
...
@@ -1407,8 +1407,10 @@ def mm_dequant(
...
@@ -1407,8 +1407,10 @@ def mm_dequant(
out
=
None
,
out
=
None
,
new_row_stats
=
None
,
new_row_stats
=
None
,
new_col_stats
=
None
,
new_col_stats
=
None
,
bias
=
None
):
):
assert
A
.
dtype
==
torch
.
int32
assert
A
.
dtype
==
torch
.
int32
if
bias
is
not
None
:
assert
bias
.
dtype
==
torch
.
float16
out_shape
=
quant_state
[
0
]
out_shape
=
quant_state
[
0
]
if
len
(
out_shape
)
==
3
:
if
len
(
out_shape
)
==
3
:
out_shape
=
(
out_shape
[
0
]
*
out_shape
[
1
],
out_shape
[
2
])
out_shape
=
(
out_shape
[
0
]
*
out_shape
[
1
],
out_shape
[
2
])
...
@@ -1430,17 +1432,20 @@ def mm_dequant(
...
@@ -1430,17 +1432,20 @@ def mm_dequant(
new_col_stats
.
shape
[
0
]
==
col_stats
.
shape
[
0
]
new_col_stats
.
shape
[
0
]
==
col_stats
.
shape
[
0
]
),
f
"
{
new_col_stats
.
shape
}
vs
{
col_stats
.
shape
}
"
),
f
"
{
new_col_stats
.
shape
}
vs
{
col_stats
.
shape
}
"
prev_device
=
pre_call
(
A
.
device
)
ptrA
=
get_ptr
(
A
)
ptrA
=
get_ptr
(
A
)
ptrOut
=
get_ptr
(
out
)
ptrOut
=
get_ptr
(
out
)
ptrRowStats
=
get_ptr
(
row_stats
)
ptrRowStats
=
get_ptr
(
row_stats
)
ptrColStats
=
get_ptr
(
col_stats
)
ptrColStats
=
get_ptr
(
col_stats
)
ptrNewRowStats
=
get_ptr
(
new_row_stats
)
ptrNewRowStats
=
get_ptr
(
new_row_stats
)
ptrNewColStats
=
get_ptr
(
new_col_stats
)
ptrNewColStats
=
get_ptr
(
new_col_stats
)
ptrBias
=
get_ptr
(
bias
)
numRows
=
ct
.
c_int32
(
out_shape
[
0
])
numRows
=
ct
.
c_int32
(
out_shape
[
0
])
numCols
=
ct
.
c_int32
(
out_shape
[
1
])
numCols
=
ct
.
c_int32
(
out_shape
[
1
])
is_on_gpu
([
A
,
row_stats
,
col_stats
,
out
,
new_row_stats
,
new_col_stats
])
is_on_gpu
([
A
,
row_stats
,
col_stats
,
out
,
new_row_stats
,
new_col_stats
,
bias
])
lib
.
cdequant_mm_int32_fp16
(
ptrA
,
ptrRowStats
,
ptrColStats
,
ptrOut
,
ptrNewRowStats
,
ptrNewColStats
,
numRows
,
numCols
)
lib
.
cdequant_mm_int32_fp16
(
ptrA
,
ptrRowStats
,
ptrColStats
,
ptrOut
,
ptrNewRowStats
,
ptrNewColStats
,
ptrBias
,
numRows
,
numCols
)
post_call
(
prev_device
)
return
out
return
out
...
...
csrc/kernels.cu
View file @
1ed2fa2f
...
@@ -1889,7 +1889,7 @@ template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 1>(half * __rest
...
@@ -1889,7 +1889,7 @@ template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 1>(half * __rest
#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)
#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)
template
<
int
ITEMS_PER_THREAD
,
int
SUBTILE_ROWS
,
int
THREADS
>
__global__
void
kdequant_mm_int32_fp16
(
int
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
half
*
out
,
float
*
newRowStats
,
float
*
newcolStats
,
const
int
numRows
,
const
int
numCols
,
const
int
tileCols
,
const
int
n
)
template
<
int
ITEMS_PER_THREAD
,
int
SUBTILE_ROWS
,
int
THREADS
>
__global__
void
kdequant_mm_int32_fp16
(
int
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
half
*
out
,
float
*
newRowStats
,
float
*
newcolStats
,
half
*
__restrict__
const
bias
,
const
int
numRows
,
const
int
numCols
,
const
int
tileCols
,
const
int
n
)
{
{
// Strategy: To dequantize we need to load col/row statistics. This can be very expensive
// Strategy: To dequantize we need to load col/row statistics. This can be very expensive
...
@@ -2675,7 +2675,7 @@ template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(
...
@@ -2675,7 +2675,7 @@ template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(
template
__global__
void
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
0
,
COL_AMPERE
>(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
__global__
void
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
0
,
COL_AMPERE
>(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
__global__
void
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
1
,
COL_AMPERE
>(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
__global__
void
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
1
,
COL_AMPERE
>(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
__global__
void
kdequant_mm_int32_fp16
<
4
,
128
,
512
>(
int
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
half
*
out
,
float
*
newRowStats
,
float
*
newcolStats
,
const
int
numRows
,
const
int
numCols
,
const
int
tileCols
,
const
int
n
);
template
__global__
void
kdequant_mm_int32_fp16
<
4
,
128
,
512
>(
int
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
half
*
out
,
float
*
newRowStats
,
float
*
newcolStats
,
half
*
__restrict__
const
bias
,
const
int
numRows
,
const
int
numCols
,
const
int
tileCols
,
const
int
n
);
template
__global__
void
kDoubleRowColQuant
<
64
,
4
,
16
,
64
*
4
,
0
>(
half
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
char
*
out_col_normed
,
char
*
out_row_normed
,
int
*
rowidx
,
int
*
colidx
,
half
*
val
,
int
*
__restrict__
nnz_block_ptr
,
float
threshold
,
int
rows
,
int
cols
,
int
tiledCols
);
template
__global__
void
kDoubleRowColQuant
<
64
,
4
,
16
,
64
*
4
,
0
>(
half
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
char
*
out_col_normed
,
char
*
out_row_normed
,
int
*
rowidx
,
int
*
colidx
,
half
*
val
,
int
*
__restrict__
nnz_block_ptr
,
float
threshold
,
int
rows
,
int
cols
,
int
tiledCols
);
template
__global__
void
kDoubleRowColQuant
<
64
,
4
,
16
,
64
*
4
,
1
>(
half
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
char
*
out_col_normed
,
char
*
out_row_normed
,
int
*
rowidx
,
int
*
colidx
,
half
*
val
,
int
*
__restrict__
nnz_block_ptr
,
float
threshold
,
int
rows
,
int
cols
,
int
tiledCols
);
template
__global__
void
kDoubleRowColQuant
<
64
,
4
,
16
,
64
*
4
,
1
>(
half
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
char
*
out_col_normed
,
char
*
out_row_normed
,
int
*
rowidx
,
int
*
colidx
,
half
*
val
,
int
*
__restrict__
nnz_block_ptr
,
float
threshold
,
int
rows
,
int
cols
,
int
tiledCols
);
...
...
csrc/kernels.cuh
View file @
1ed2fa2f
...
@@ -111,7 +111,7 @@ template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_s
...
@@ -111,7 +111,7 @@ template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_s
template
<
int
ITEMS_PER_THREAD
,
int
SUBTILE_ROWS
,
int
THREADS
>
__global__
void
kdequant_mm_int32_fp16
(
template
<
int
ITEMS_PER_THREAD
,
int
SUBTILE_ROWS
,
int
THREADS
>
__global__
void
kdequant_mm_int32_fp16
(
int
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
int
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
half
*
out
,
float
*
newRowStats
,
float
*
newcolStats
,
const
int
numRows
,
const
int
numCols
,
const
int
tileCols
,
const
int
n
);
half
*
out
,
float
*
newRowStats
,
float
*
newcolStats
,
half
*
__restrict__
const
bias
,
const
int
numRows
,
const
int
numCols
,
const
int
tileCols
,
const
int
n
);
template
<
typename
T
,
int
THREADS
,
int
ITEMS_PER_THREAD
,
int
TILE_ROWS
,
int
TILE_COLS
,
int
SPARSE_DECOMP
>
__global__
void
kgetColRowStats
(
T
*
__restrict__
A
,
float
*
rowStats
,
float
*
colStats
,
int
*
nnz_count_row
,
float
nnz_threshold
,
int
rows
,
int
cols
,
int
tiledRows
,
int
tiledCols
);
template
<
typename
T
,
int
THREADS
,
int
ITEMS_PER_THREAD
,
int
TILE_ROWS
,
int
TILE_COLS
,
int
SPARSE_DECOMP
>
__global__
void
kgetColRowStats
(
T
*
__restrict__
A
,
float
*
rowStats
,
float
*
colStats
,
int
*
nnz_count_row
,
float
nnz_threshold
,
int
rows
,
int
cols
,
int
tiledRows
,
int
tiledCols
);
template
<
int
THREADS
,
int
ITEMS_PER_THREAD
,
int
TILE_ROWS
,
int
TILE_COLS
,
int
SPARSE_DECOMP
>
__global__
void
kDoubleRowColQuant
(
half
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
char
*
out_col_normed
,
char
*
out_row_normed
,
int
*
rowidx
,
int
*
colidx
,
half
*
val
,
int
*
__restrict__
nnz_block_ptr
,
float
threshold
,
int
rows
,
int
cols
,
int
tiledCols
);
template
<
int
THREADS
,
int
ITEMS_PER_THREAD
,
int
TILE_ROWS
,
int
TILE_COLS
,
int
SPARSE_DECOMP
>
__global__
void
kDoubleRowColQuant
(
half
*
__restrict__
const
A
,
float
*
__restrict__
const
rowStats
,
float
*
__restrict__
const
colStats
,
char
*
out_col_normed
,
char
*
out_row_normed
,
int
*
rowidx
,
int
*
colidx
,
half
*
val
,
int
*
__restrict__
nnz_block_ptr
,
float
threshold
,
int
rows
,
int
cols
,
int
tiledCols
);
...
...
csrc/ops.cu
View file @
1ed2fa2f
...
@@ -435,7 +435,7 @@ int fill_up_to_nearest_multiple(int value, int multiple)
...
@@ -435,7 +435,7 @@ int fill_up_to_nearest_multiple(int value, int multiple)
return
value
+
(
value
%
multiple
==
0
?
0
:
(
multiple
-
(
value
%
multiple
)));
return
value
+
(
value
%
multiple
==
0
?
0
:
(
multiple
-
(
value
%
multiple
)));
}
}
void
dequant_mm_int32_fp16
(
int
*
A
,
float
*
rowStats
,
float
*
colStats
,
half
*
out
,
float
*
newRowStats
,
float
*
newcolStats
,
int
numRows
,
int
numCols
)
void
dequant_mm_int32_fp16
(
int
*
A
,
float
*
rowStats
,
float
*
colStats
,
half
*
out
,
float
*
newRowStats
,
float
*
newcolStats
,
half
*
bias
,
int
numRows
,
int
numCols
)
{
{
int
threads
=
512
;
int
threads
=
512
;
int
tileCols
=
fill_up_to_nearest_multiple
(
numCols
,
32
);
int
tileCols
=
fill_up_to_nearest_multiple
(
numCols
,
32
);
...
@@ -447,7 +447,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
...
@@ -447,7 +447,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
num_blocks
=
num_blocks
*
(
tileCols
/
32
);
num_blocks
=
num_blocks
*
(
tileCols
/
32
);
assert
(
threads
<=
tilesize
);
assert
(
threads
<=
tilesize
);
kdequant_mm_int32_fp16
<
4
,
128
,
512
><<<
num_blocks
,
threads
>>>
(
A
,
rowStats
,
colStats
,
out
,
newRowStats
,
newcolStats
,
numRows
,
numCols
,
tileCols
,
n
);
kdequant_mm_int32_fp16
<
4
,
128
,
512
><<<
num_blocks
,
threads
>>>
(
A
,
rowStats
,
colStats
,
out
,
newRowStats
,
newcolStats
,
bias
,
numRows
,
numCols
,
tileCols
,
n
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
}
}
...
@@ -465,7 +465,6 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
...
@@ -465,7 +465,6 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
col_tiles
=
col_tiles
>
0
?
col_tiles
:
1
;
col_tiles
=
col_tiles
>
0
?
col_tiles
:
1
;
int
num_blocks
=
row_tiles
*
col_tiles
;
int
num_blocks
=
row_tiles
*
col_tiles
;
if
(
nnz_threshold
==
0.0
)
if
(
nnz_threshold
==
0.0
)
kgetColRowStats
<
half
,
STATS_THREADS
,
STATS_ITEMS
,
STATS_ROWS
,
STATS_THREADS
*
STATS_ITEMS
,
0
><<<
num_blocks
,
STATS_THREADS
>>>
(
A
,
rowStats
,
colStats
,
nnz_count_row
,
nnz_threshold
,
rows
,
cols
,
tiledRows
,
tiledCols
);
kgetColRowStats
<
half
,
STATS_THREADS
,
STATS_ITEMS
,
STATS_ROWS
,
STATS_THREADS
*
STATS_ITEMS
,
0
><<<
num_blocks
,
STATS_THREADS
>>>
(
A
,
rowStats
,
colStats
,
nnz_count_row
,
nnz_threshold
,
rows
,
cols
,
tiledRows
,
tiledCols
);
else
if
(
nnz_threshold
!=
0.0
)
else
if
(
nnz_threshold
!=
0.0
)
...
...
csrc/ops.cuh
View file @
1ed2fa2f
...
@@ -163,7 +163,7 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle
...
@@ -163,7 +163,7 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle
template
<
typename
T
,
int
SRC
,
int
TARGET
,
bool
transpose
,
int
DTYPE
>
void
transform
(
cublasLtHandle_t
ltHandle
,
T
*
A
,
T
*
out
,
int
dim1
,
int
dim2
);
template
<
typename
T
,
int
SRC
,
int
TARGET
,
bool
transpose
,
int
DTYPE
>
void
transform
(
cublasLtHandle_t
ltHandle
,
T
*
A
,
T
*
out
,
int
dim1
,
int
dim2
);
void
cutlass_igemm
(
bool
transposeA
,
bool
transposeB
,
int
m
,
int
n
,
int
k
,
void
*
A
,
void
*
B
,
void
*
C
,
int
lda
,
int
ldb
,
int
ldc
);
void
cutlass_igemm
(
bool
transposeA
,
bool
transposeB
,
int
m
,
int
n
,
int
k
,
void
*
A
,
void
*
B
,
void
*
C
,
int
lda
,
int
ldb
,
int
ldc
);
void
dequant_mm_int32_fp16
(
int
*
A
,
float
*
rowStats
,
float
*
colStats
,
half
*
out
,
float
*
newRowStats
,
float
*
newcolStats
,
int
numRows
,
int
numCols
);
void
dequant_mm_int32_fp16
(
int
*
A
,
float
*
rowStats
,
float
*
colStats
,
half
*
out
,
float
*
newRowStats
,
float
*
newcolStats
,
half
*
bias
,
int
numRows
,
int
numCols
);
void
getColRowStats
(
half
*
A
,
float
*
rowStats
,
float
*
colStats
,
int
*
nnz_count_row
,
float
nnz_threshold
,
int
rows
,
int
cols
);
void
getColRowStats
(
half
*
A
,
float
*
rowStats
,
float
*
colStats
,
int
*
nnz_count_row
,
float
nnz_threshold
,
int
rows
,
int
cols
);
void
doubleRowColQuant
(
half
*
A
,
float
*
rowStats
,
float
*
colStats
,
char
*
out_col_normed
,
char
*
out_row_normed
,
void
doubleRowColQuant
(
half
*
A
,
float
*
rowStats
,
float
*
colStats
,
char
*
out_col_normed
,
char
*
out_row_normed
,
int
*
rowidx
,
int
*
colidx
,
half
*
val
,
int
*
nnz_block_ptr
,
float
threshold
,
int
rows
,
int
cols
);
int
*
rowidx
,
int
*
colidx
,
half
*
val
,
int
*
nnz_block_ptr
,
float
threshold
,
int
rows
,
int
cols
);
...
...
csrc/pythonInterface.c
View file @
1ed2fa2f
...
@@ -248,8 +248,8 @@ extern "C"
...
@@ -248,8 +248,8 @@ extern "C"
MAKE_FUNC_CTRANSFORM
(
8
,
col32
,
row
,
n
,
int8_t
,
COL32
,
ROW
,
false
,
8
)
MAKE_FUNC_CTRANSFORM
(
8
,
col32
,
row
,
n
,
int8_t
,
COL32
,
ROW
,
false
,
8
)
MAKE_FUNC_CTRANSFORM
(
32
,
col32
,
row
,
n
,
int32_t
,
COL32
,
ROW
,
false
,
32
)
MAKE_FUNC_CTRANSFORM
(
32
,
col32
,
row
,
n
,
int32_t
,
COL32
,
ROW
,
false
,
32
)
void
cdequant_mm_int32_fp16
(
int
*
A
,
float
*
rowStats
,
float
*
colStats
,
half
*
out
,
float
*
newRowStats
,
float
*
newcolStats
,
int
numRows
,
int
numCols
)
void
cdequant_mm_int32_fp16
(
int
*
A
,
float
*
rowStats
,
float
*
colStats
,
half
*
out
,
float
*
newRowStats
,
float
*
newcolStats
,
half
*
bias
,
int
numRows
,
int
numCols
)
{
dequant_mm_int32_fp16
(
A
,
rowStats
,
colStats
,
out
,
newRowStats
,
newcolStats
,
numRows
,
numCols
);
}
{
dequant_mm_int32_fp16
(
A
,
rowStats
,
colStats
,
out
,
newRowStats
,
newcolStats
,
bias
,
numRows
,
numCols
);
}
void
cget_col_row_stats
(
half
*
A
,
float
*
rowStats
,
float
*
colStats
,
int
*
nnz_count_row
,
float
nnz_threshold
,
int
rows
,
int
cols
)
void
cget_col_row_stats
(
half
*
A
,
float
*
rowStats
,
float
*
colStats
,
int
*
nnz_count_row
,
float
nnz_threshold
,
int
rows
,
int
cols
)
{
getColRowStats
(
A
,
rowStats
,
colStats
,
nnz_count_row
,
nnz_threshold
,
rows
,
cols
);
}
{
getColRowStats
(
A
,
rowStats
,
colStats
,
nnz_count_row
,
nnz_threshold
,
rows
,
cols
);
}
...
...
tests/test_functional.py
View file @
1ed2fa2f
...
@@ -961,20 +961,24 @@ dim4 = torch.randint(64, 1024, size=(n,)).tolist()
...
@@ -961,20 +961,24 @@ dim4 = torch.randint(64, 1024, size=(n,)).tolist()
dims
=
(
2
,)
dims
=
(
2
,)
# ldb = list(range(256, 1*1024, 256))
# ldb = list(range(256, 1*1024, 256))
formatB
=
[
"col_turing"
,
"col_ampere"
]
formatB
=
[
"col_turing"
,
"col_ampere"
]
values
=
list
(
product
(
dim1
,
dim4
,
dims
,
formatB
))
has_bias
=
[
True
,
False
]
values
=
list
(
product
(
dim1
,
dim4
,
dims
,
formatB
,
has_bias
))
names
=
[
names
=
[
"dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}"
.
format
(
*
vals
)
for
vals
in
values
"dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}
_has_bias_{4}
"
.
format
(
*
vals
)
for
vals
in
values
]
]
@
pytest
.
mark
.
parametrize
(
"dim1, dim4, dims, formatB"
,
values
,
ids
=
names
)
@
pytest
.
mark
.
parametrize
(
"dim1, dim4, dims, formatB
, has_bias
"
,
values
,
ids
=
names
)
def
test_dequant_mm
(
dim1
,
dim4
,
dims
,
formatB
):
def
test_dequant_mm
(
dim1
,
dim4
,
dims
,
formatB
,
has_bias
):
inner
=
torch
.
randint
(
1
,
128
,
size
=
(
1
,)).
item
()
inner
=
torch
.
randint
(
1
,
128
,
size
=
(
1
,)).
item
()
bias
=
None
if
has_bias
:
bias
=
torch
.
randn
(
dim4
,
device
=
'cuda'
,
dtype
=
torch
.
float16
)
formatB
=
F
.
get_special_format_str
()
formatB
=
F
.
get_special_format_str
()
for
i
in
range
(
k
):
for
i
in
range
(
k
):
A
=
torch
.
randn
(
dim1
,
inner
,
device
=
"cuda"
)
A
=
torch
.
randn
(
dim1
,
inner
,
device
=
"cuda"
)
B
=
torch
.
randn
(
dim4
,
inner
,
device
=
"cuda"
)
B
=
torch
.
randn
(
dim4
,
inner
,
device
=
"cuda"
)
C1
=
torch
.
matmul
(
A
.
half
(),
B
.
t
().
half
())
C1
=
torch
.
matmul
(
A
.
half
(),
B
.
t
().
half
())
if
has_bias
:
C1
+=
bias
A1
,
maxA
=
F
.
vectorwise_quant
(
A
,
dim
=
1
)
A1
,
maxA
=
F
.
vectorwise_quant
(
A
,
dim
=
1
)
B1
,
maxB
=
F
.
vectorwise_quant
(
B
,
dim
=
1
)
B1
,
maxB
=
F
.
vectorwise_quant
(
B
,
dim
=
1
)
...
@@ -985,17 +989,15 @@ def test_dequant_mm(dim1, dim4, dims, formatB):
...
@@ -985,17 +989,15 @@ def test_dequant_mm(dim1, dim4, dims, formatB):
C3
,
S
=
F
.
nvidia_transform
(
C2
,
"row"
,
state
=
SC
)
C3
,
S
=
F
.
nvidia_transform
(
C2
,
"row"
,
state
=
SC
)
C4
=
F
.
vectorwise_mm_dequant
(
C3
.
float
(),
maxA
,
maxB
.
t
())
C4
=
F
.
vectorwise_mm_dequant
(
C3
.
float
(),
maxA
,
maxB
.
t
())
if
has_bias
:
C4
+=
bias
count
=
(
torch
.
isclose
(
C1
,
C4
,
atol
=
0.01
,
rtol
=
0.1
)
==
0
).
sum
().
item
()
count
=
(
torch
.
isclose
(
C1
,
C4
,
atol
=
0.01
,
rtol
=
0.1
)
==
0
).
sum
().
item
()
n
=
C1
.
numel
()
n
=
C1
.
numel
()
p
=
0.06
p
=
0.06
assert
(
assert
(
count
/
n
<
p
),
f
"error in more than
{
p
}
of elements:
{
count
}
/
{
n
}
=
{
count
/
n
}
"
count
/
n
<
p
),
f
"error in more than
{
p
}
of elements:
{
count
}
/
{
n
}
=
{
count
/
n
}
"
C5
=
F
.
mm_dequant
(
C2
,
SC
,
maxA
.
flatten
(),
maxB
.
flatten
())
C5
=
F
.
mm_dequant
(
C2
,
SC
,
maxA
.
flatten
(),
maxB
.
flatten
()
,
bias
=
bias
)
torch
.
testing
.
assert_allclose
(
C5
,
C4
)
torch
.
testing
.
assert_allclose
(
C5
,
C4
)
# print(C2)
n
=
2
n
=
2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment