Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
cbb901ac
Commit
cbb901ac
authored
Jul 26, 2022
by
Tim Dettmers
Browse files
Boilerplate and test for extract_outliers.
parent
c771b3a7
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
87 additions
and
0 deletions
+87
-0
bitsandbytes/functional.py
bitsandbytes/functional.py
+26
-0
csrc/kernels.cu
csrc/kernels.cu
+7
-0
csrc/kernels.cuh
csrc/kernels.cuh
+2
-0
csrc/ops.cu
csrc/ops.cu
+27
-0
csrc/ops.cuh
csrc/ops.cuh
+2
-0
csrc/pythonInterface.c
csrc/pythonInterface.c
+6
-0
tests/test_functional.py
tests/test_functional.py
+17
-0
No files found.
bitsandbytes/functional.py
View file @
cbb901ac
...
...
@@ -1409,3 +1409,29 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
x
*=
SA
[
1
]
/
127
x
+=
offset
return
x
.
to
(
dtype
)
def extract_outliers(A, SA, idx):
    """Call the native outlier-extraction routine for a tiled int8 matrix.

    Args:
        A: CUDA tensor holding the matrix data; its raw pointer is passed to
            the native library.
        SA: State pair for ``A``; ``SA[0]`` is the (rows, cols) shape and
            ``SA[1]`` is the layout name, which must be ``'col_turing'`` or
            ``'col_ampere'``.
        idx: Tensor of indices to extract. It is converted to a contiguous
            int32 tensor on ``A``'s device before its pointer is taken.

    Returns:
        An int8 tensor of shape ``(SA[0][0], idx.numel())`` on ``A``'s device.
        It is allocated zero-filled and passed to the native routine as the
        output buffer.
    """
    shapeA = SA[0]
    formatA = SA[1]
    assert formatA in ['col_turing', 'col_ampere']
    assert A.device.type == 'cuda'

    # The native entry points declare `int *idx` and forward it to a CUDA
    # kernel, so the index buffer must be contiguous int32 device memory.
    # Passing a CPU or int64 tensor's pointer would hand the kernel a host
    # address or mis-sized elements.
    idx = idx.to(device=A.device, dtype=torch.int32).contiguous()

    out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)

    idx_size = ct.c_int32(idx.numel())
    rows = ct.c_int32(shapeA[0])
    cols = ct.c_int32(shapeA[1])
    ptrA = get_ptr(A)
    ptrIdx = get_ptr(idx)
    ptrOut = get_ptr(out)

    if formatA == 'col_turing':
        lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
    elif formatA == 'col_ampere':
        lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)

    return out
csrc/kernels.cu
View file @
cbb901ac
...
...
@@ -2592,10 +2592,17 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
}
}
// Device kernel for outlier extraction, templated on the tiled layout FORMAT.
// The body is intentionally empty here: this is boilerplate so the launcher,
// explicit instantiations and bindings can be wired up before the kernel
// logic exists.
template <int FORMAT>
__global__ void kExtractOutliers(char *A, int *idx, char *out,
                                 int rowsA, int colsA,
                                 int tiledRowsA, int tiledColsA)
{
}
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
template
__global__
void
kExtractOutliers
<
COL_TURING
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
template
__global__
void
kExtractOutliers
<
COL_AMPERE
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
8
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
16
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
32
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
...
...
csrc/kernels.cuh
View file @
cbb901ac
...
...
@@ -118,6 +118,8 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int S
template
<
int
THREADS
,
int
ITEMS_PER_THREAD
,
int
TILE_ROWS
,
int
TILE_COLS
,
int
TRANSPOSE
,
int
FORMAT
>
__global__
void
kTransformRowToFormat
(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
<
int
FORMAT
>
__global__
void
kExtractOutliers
(
char
*
A
,
int
*
idx
,
char
*
out
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
#endif
csrc/ops.cu
View file @
cbb901ac
...
...
@@ -578,10 +578,37 @@ template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count,
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
}
// Host-side launcher for kExtractOutliers<FORMAT>.
//
//   A        - input matrix buffer, forwarded to the kernel
//   idx      - index buffer, forwarded to the kernel
//   out      - output buffer, forwarded to the kernel
//   idx_size - number of entries in idx
//   rows     - row count of A
//   cols     - column count of A
//
// Computes the padded tile dimensions for the given FORMAT, launches one
// thread per (index, column) pair rounded up to whole blocks of 256 threads,
// and reports any launch error through CUDA_CHECK_RETURN.
template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols)
{
  const int threads = 256;
  // we load 128 column values per warp
  const int tiledCols = fill_up_to_nearest_multiple(cols, 32);
  int tiledRows = 0;

  // matrix A is transposed, so we extract columns
  // NOTE(review): the Python caller allocates `out` as rows x idx_size;
  // confirm that idx_size*cols (rather than idx_size*rows) is the intended
  // amount of work once the kernel body is implemented.
  const int elements = idx_size*cols;

  // Nothing to do for an empty index list or an empty matrix. A launch with
  // zero blocks is rejected by the CUDA runtime as an invalid configuration,
  // so return before launching instead of surfacing a spurious error.
  if(elements == 0){ return; }

  const int num_blocks = (elements+threads-1)/threads;

  if(FORMAT == COL_TURING)
  {
    tiledRows = fill_up_to_nearest_multiple(rows, 8);
  }
  else if(FORMAT == COL_AMPERE)
  {
    tiledRows = fill_up_to_nearest_multiple(rows, 32);
  }

  kExtractOutliers<FORMAT><<<num_blocks, threads>>>(A, idx, out, rows, cols, tiledRows, tiledCols);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
template
void
extractOutliers
<
COL_TURING
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rows
,
int
cols
);
template
void
extractOutliers
<
COL_AMPERE
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rows
,
int
cols
);
template
void
spmm_coo_very_sparse_naive
<
half
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz_rows
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
void
spmm_coo_very_sparse_naive
<
signed
char
,
8
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
signed
char
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz_rows
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
...
...
csrc/ops.cuh
View file @
cbb901ac
...
...
@@ -174,4 +174,6 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val
template
<
typename
T
,
int
BITS
>
void
spmm_coo_very_sparse_naive
(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
T
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz_rows
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
<
int
FORMAT
>
void
extractOutliers
(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rows
,
int
cols
);
#endif
csrc/pythonInterface.c
View file @
cbb901ac
...
...
@@ -106,6 +106,9 @@ void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRo
// Non-template entry point: forwards to transformRowToFormat<COL_AMPERE, 0>.
void transform_row2ampere(char *A, char *out, int rows, int cols)
{
  transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols);
}
// Non-template entry point: forwards to transformRowToFormat<COL_AMPERE, 1>.
void transform_row2ampereT(char *A, char *out, int rows, int cols)
{
  transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols);
}
// Non-template entry point: forwards to extractOutliers<COL_TURING>.
void extractOutliers_turing(char *A, int *idx, char *out, int idx_size, int rows, int cols)
{
  extractOutliers<COL_TURING>(A, idx, out, idx_size, rows, cols);
}
// Non-template entry point: forwards to extractOutliers<COL_AMPERE>.
void extractOutliers_ampere(char *A, int *idx, char *out, int idx_size, int rows, int cols)
{
  extractOutliers<COL_AMPERE>(A, idx, out, idx_size, rows, cols);
}
// Non-template entry point: forwards to igemmlt<COL_TURING, 32, 0> and
// returns its result unchanged.
int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k,
                      const int8_t *A, const int8_t *B, void *C,
                      float *row_scale, int lda, int ldb, int ldc)
{
  return igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k,
                                    A, B, C,
                                    row_scale, lda, ldb, ldc);
}
...
...
@@ -280,6 +283,9 @@ extern "C"
// Exported wrapper: passes every argument through, unchanged and in order,
// to spmm_coo_very_sparse_naive_int8.
void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx,
                                      int *rowidx, int *colidx, half *values,
                                      signed char *B, half *out, float *dequant_stats,
                                      int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{
  spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx,
                                  rowidx, colidx, values,
                                  B, out, dequant_stats,
                                  nnz_rows, nnz, rowsA, rowsB, colsB);
}
// Exported wrapper (called from Python as lib.cextractOutliers_turing):
// forwards to extractOutliers_turing.
void cextractOutliers_turing(char *A, int *idx, char *out, int idx_size, int rows, int cols)
{
  extractOutliers_turing(A, idx, out, idx_size, rows, cols);
}
// Exported wrapper (called from Python as lib.cextractOutliers_ampere):
// forwards to extractOutliers_ampere.
void cextractOutliers_ampere(char *A, int *idx, char *out, int idx_size, int rows, int cols)
{
  extractOutliers_ampere(A, idx, out, idx_size, rows, cols);
}
#endif
// Exported wrapper: forwards all arguments to quantize_cpu.
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax,
                                  unsigned char *out, const int n)
{
  quantize_cpu(code, A, absmax, out, n);
}
// Exported wrapper: forwards all arguments to dequantize_cpu.
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax,
                                    float *out, const int n)
{
  dequantize_cpu(code, A, absmax, out, n);
}
...
...
tests/test_functional.py
View file @
cbb901ac
...
...
@@ -1856,3 +1856,20 @@ def test_zp():
print
(
err1
,
err2
,
err3
,
err4
,
err5
,
err6
)
def test_extract_outliers():
    """Compare F.extract_outliers on a col_turing matrix against plain column indexing."""
    shapeA = (128, 128)
    # The index tensor's raw pointer is handed to a CUDA routine, so it must
    # live on the GPU as int32 rather than on the CPU.
    idx = torch.randint(0, shapeA[1], size=(10,), device='cuda').int()
    # randint's upper bound is exclusive; 128 is needed to cover the full
    # int8 range [-128, 127].
    A = torch.randint(-128, 128, size=shapeA, device='cuda').to(torch.int8)
    # Reference result: select the indexed columns directly from the row-major matrix.
    outliers1 = A[:, idx.long()]

    CA, SA = F.transform(A, 'col_turing')

    outliers2 = F.extract_outliers(CA, SA, idx)

    assert outliers2.shape[0] == shapeA[0]
    assert outliers2.shape[1] == idx.numel()

    torch.testing.assert_allclose(outliers1, outliers2)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment