Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
7d2ecd30
Commit
7d2ecd30
authored
Jul 22, 2022
by
Tim Dettmers
Browse files
Fixed rowcol synchronization bug.
parent
c771b3a7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
2 deletions
+2
-2
csrc/kernels.cu
csrc/kernels.cu
+2
-2
No files found.
csrc/kernels.cu
View file @
7d2ecd30
...
@@ -1768,7 +1768,6 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
...
@@ -1768,7 +1768,6 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
__shared__
float
smem_row_absmax_values
[
ITEMS_PER_THREAD
*
THREADS
];
__shared__
float
smem_row_absmax_values
[
ITEMS_PER_THREAD
*
THREADS
];
__shared__
int
smem_row_nnz_values
[
TILE_ROWS
];
__shared__
int
smem_row_nnz_values
[
TILE_ROWS
];
//__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS];
half
local_data
[
ITEMS_PER_THREAD
];
half
local_data
[
ITEMS_PER_THREAD
];
float
local_data_fp32
[
ITEMS_PER_THREAD
];
float
local_data_fp32
[
ITEMS_PER_THREAD
];
...
@@ -1828,13 +1827,14 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
...
@@ -1828,13 +1827,14 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
local_col_absmax_values
[
j
]
=
fmaxf
(
local_col_absmax_values
[
j
],
__half2float
(
local_data
[
j
]));
local_col_absmax_values
[
j
]
=
fmaxf
(
local_col_absmax_values
[
j
],
__half2float
(
local_data
[
j
]));
// 3. compute row max (per block); store in smem to accumulate full global mem transation
// 3. compute row max (per block); store in smem to accumulate full global mem transation
__syncthreads
();
// this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
// this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
#pragma unroll ITEMS_PER_THREAD
#pragma unroll ITEMS_PER_THREAD
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
for
(
int
j
=
0
;
j
<
ITEMS_PER_THREAD
;
j
++
)
local_data_fp32
[
j
]
=
local_data
[
j
];
local_data_fp32
[
j
]
=
local_data
[
j
];
__syncthreads
();
row_absmax
=
(
float
)
BlockRowReduce
(
temp_storage
.
rowreduce
).
Reduce
(
local_data_fp32
,
cub
::
Max
());
row_absmax
=
(
float
)
BlockRowReduce
(
temp_storage
.
rowreduce
).
Reduce
(
local_data_fp32
,
cub
::
Max
());
if
(
SPARSE_DECOMP
)
if
(
SPARSE_DECOMP
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment