Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
bcab99ec
"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "2eaeb2a8cd17880360c3278ece5aabbdaa9dc397"
Commit
bcab99ec
authored
Jul 26, 2022
by
Tim Dettmers
Browse files
Working outlier extraction for Turing.
parent
cbb901ac
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
74 additions
and
17 deletions
+74
-17
csrc/kernels.cu
csrc/kernels.cu
+58
-3
csrc/kernels.cuh
csrc/kernels.cuh
+1
-1
csrc/ops.cu
csrc/ops.cu
+2
-3
tests/test_functional.py
tests/test_functional.py
+13
-10
No files found.
csrc/kernels.cu
View file @
bcab99ec
...
@@ -2592,16 +2592,71 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
...
@@ -2592,16 +2592,71 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
}
}
}
}
template
<
int
FORMAT
>
__global__
void
kExtractOutliers
(
char
*
A
,
int
*
idx
,
char
*
out
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
)
template
<
int
FORMAT
>
__global__
void
kExtractOutliers
(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
)
{
{
int
local_colidx
=
idx
[
blockIdx
.
x
];
if
(
FORMAT
==
COL_TURING
)
{
// TURING FORMAT:
// 8*32 tiles with 4*4 subtiles
// the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements)
// the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
// the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
// the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
// index increases by 32
// columns are grouped in increments of 4, meaning that one has the following rows and columns
// rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
// cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...]
// each thread reads 1 element = 1 row
for
(
int
row
=
threadIdx
.
x
;
row
<
rowsA
;
row
+=
blockDim
.
x
)
{
int
offset_per_col_tile
=
((
rowsA
+
7
)
/
8
)
*
32
*
8
;
int
tile_offset_rows
=
(
row
/
8
)
*
32
*
8
;
int
tile_offset_cols
=
(
local_colidx
/
32
)
*
offset_per_col_tile
;
int
offset
=
0
;
int
subtile_col_idx
=
local_colidx
%
32
;
int
subtile_row_idx
=
row
%
8
;
if
(
row
%
2
==
1
)
offset
+=
128
+
(
subtile_col_idx
/
4
)
*
16
+
(
subtile_col_idx
%
4
)
+
((
subtile_row_idx
-
1
)
*
2
);
else
// even
offset
+=
0
+
(
subtile_col_idx
/
4
)
*
16
+
(
subtile_col_idx
%
4
)
+
(
subtile_row_idx
*
2
);
offset
+=
tile_offset_rows
+
tile_offset_cols
;
char
val
=
0
;
//printf("(%i (%i %i) (%i %i))\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx);
if
(
offset
>
tiledColsA
*
tiledRowsA
)
printf
(
"(%i (%i %i) (%i %i)
\n
"
,
offset
,
tile_offset_rows
,
tile_offset_cols
,
row
,
local_colidx
);
else
val
=
A
[
offset
];
int
out_idx
=
(
row
*
idx_size
)
+
blockIdx
.
x
;
//if(out_idx > colsA*idx_size)
if
(
val
!=
0
)
{
//printf("(%i %i) = (%i) = %i\n", row, local_colidx, out_idx, (int) val);
out
[
out_idx
]
=
val
;
}
else
{
out
[
out_idx
]
=
val
;
}
}
}
}
}
//==============================================================
//==============================================================
// TEMPLATE DEFINITIONS
// TEMPLATE DEFINITIONS
//==============================================================
//==============================================================
template
__global__
void
kExtractOutliers
<
COL_TURING
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
template
__global__
void
kExtractOutliers
<
COL_TURING
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
template
__global__
void
kExtractOutliers
<
COL_AMPERE
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
template
__global__
void
kExtractOutliers
<
COL_AMPERE
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
8
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
8
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
16
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
16
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
...
...
csrc/kernels.cuh
View file @
bcab99ec
...
@@ -118,7 +118,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int S
...
@@ -118,7 +118,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int S
template
<
int
THREADS
,
int
ITEMS_PER_THREAD
,
int
TILE_ROWS
,
int
TILE_COLS
,
int
TRANSPOSE
,
int
FORMAT
>
__global__
void
kTransformRowToFormat
(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
<
int
THREADS
,
int
ITEMS_PER_THREAD
,
int
TILE_ROWS
,
int
TILE_COLS
,
int
TRANSPOSE
,
int
FORMAT
>
__global__
void
kTransformRowToFormat
(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
<
int
FORMAT
>
__global__
void
kExtractOutliers
(
char
*
A
,
int
*
idx
,
char
*
out
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
template
<
int
FORMAT
>
__global__
void
kExtractOutliers
(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
#endif
#endif
...
...
csrc/ops.cu
View file @
bcab99ec
...
@@ -586,8 +586,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
...
@@ -586,8 +586,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
int
tiledCols
=
tiledCols
=
fill_up_to_nearest_multiple
(
cols
,
32
);
int
tiledCols
=
tiledCols
=
fill_up_to_nearest_multiple
(
cols
,
32
);
int
tiledRows
=
0
;
int
tiledRows
=
0
;
int
elements
=
idx_size
*
cols
;
// matrix A is transposed, so we extract columns
int
num_blocks
=
idx_size
;
int
num_blocks
=
(
elements
+
threads
-
1
)
/
threads
;
if
(
FORMAT
==
COL_TURING
)
if
(
FORMAT
==
COL_TURING
)
{
{
...
@@ -598,7 +597,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
...
@@ -598,7 +597,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
tiledRows
=
fill_up_to_nearest_multiple
(
rows
,
32
);
tiledRows
=
fill_up_to_nearest_multiple
(
rows
,
32
);
}
}
kExtractOutliers
<
FORMAT
><<<
num_blocks
,
threads
>>>
(
A
,
idx
,
out
,
rows
,
cols
,
tiledRows
,
tiledCols
);
kExtractOutliers
<
FORMAT
><<<
num_blocks
,
threads
>>>
(
A
,
idx
,
out
,
idx_size
,
rows
,
cols
,
tiledRows
,
tiledCols
);
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
CUDA_CHECK_RETURN
(
cudaPeekAtLastError
());
}
}
...
...
tests/test_functional.py
View file @
bcab99ec
...
@@ -1858,18 +1858,21 @@ def test_zp():
...
@@ -1858,18 +1858,21 @@ def test_zp():
def
test_extract_outliers
():
def
test_extract_outliers
():
shapeA
=
(
128
,
128
)
for
i
in
range
(
k
):
idx
=
torch
.
randint
(
0
,
shapeA
[
1
],
size
=
(
10
,)).
int
(
)
shapeA
=
(
4096
,
4
*
4096
)
A
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeA
,
device
=
'cuda'
).
to
(
torch
.
int8
)
idx
=
torch
.
unique
(
torch
.
randint
(
0
,
shapeA
[
1
],
size
=
(
10
,)).
int
()).
cuda
(
)
outliers1
=
A
[:,
idx
.
long
()]
#idx = torch.Tensor([32]).int().cuda()
A
=
torch
.
randint
(
-
128
,
127
,
size
=
shapeA
,
device
=
'cuda'
).
to
(
torch
.
int8
)
CA
,
SA
=
F
.
transform
(
A
,
'col_turing'
)
outliers1
=
A
[:,
idx
.
long
()]
outliers2
=
F
.
ex
tra
ct_outliers
(
CA
,
SA
,
idx
)
CA
,
SA
=
F
.
tra
nsform
(
A
,
'col_turing'
)
assert
outliers2
.
shape
[
0
]
==
shapeA
[
0
]
outliers2
=
F
.
extract_outliers
(
CA
,
SA
,
idx
)
assert
outliers2
.
shape
[
1
]
==
idx
.
numel
()
assert
outliers2
.
shape
[
0
]
==
shapeA
[
0
]
assert
outliers2
.
shape
[
1
]
==
idx
.
numel
()
#print(outliers1)
#print(outliers2)
torch
.
testing
.
assert_allclose
(
outliers1
,
outliers2
)
torch
.
testing
.
assert_allclose
(
outliers1
,
outliers2
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment