Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
89876bb0
Unverified
Commit
89876bb0
authored
Feb 01, 2024
by
Younes Belkada
Committed by
GitHub
Feb 01, 2024
Browse files
Merge pull request #876 from wkpark/minimal-win-fix
minimal patch to fix Windows compilation issues
parents
4261d89d
fd319d51
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
8 deletions
+24
-8
csrc/cpu_ops.cpp
csrc/cpu_ops.cpp
+18
-1
csrc/kernels.cu
csrc/kernels.cu
+6
-6
csrc/ops.cuh
csrc/ops.cuh
+0
-1
No files found.
csrc/cpu_ops.cpp
View file @
89876bb0
#include <BinSearch.h>
#include <BinSearch.h>
#ifdef _WIN32
#include <thread>
#else
#include <pthread.h>
#include <pthread.h>
#endif
#include <common.h>
#include <common.h>
using
namespace
BinSearch
;
using
namespace
BinSearch
;
...
@@ -31,7 +35,11 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
...
@@ -31,7 +35,11 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
for
(
long
long
offset
=
0
;
offset
<
num_blocks
;
offset
+=
thread_wave_size
)
for
(
long
long
offset
=
0
;
offset
<
num_blocks
;
offset
+=
thread_wave_size
)
{
{
long
long
valid_chunks
=
num_blocks
-
offset
>=
thread_wave_size
?
thread_wave_size
:
num_blocks
-
offset
;
long
long
valid_chunks
=
num_blocks
-
offset
>=
thread_wave_size
?
thread_wave_size
:
num_blocks
-
offset
;
#ifdef _WIN32
std
::
thread
*
threads
=
(
std
::
thread
*
)
malloc
(
sizeof
(
std
::
thread
)
*
valid_chunks
);
#else
pthread_t
*
threads
=
(
pthread_t
*
)
malloc
(
sizeof
(
pthread_t
)
*
valid_chunks
);
pthread_t
*
threads
=
(
pthread_t
*
)
malloc
(
sizeof
(
pthread_t
)
*
valid_chunks
);
#endif
struct
quantize_block_args
**
args
=
(
quantize_block_args
**
)
malloc
(
valid_chunks
*
sizeof
(
quantize_block_args
*
));
struct
quantize_block_args
**
args
=
(
quantize_block_args
**
)
malloc
(
valid_chunks
*
sizeof
(
quantize_block_args
*
));
...
@@ -55,14 +63,23 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
...
@@ -55,14 +63,23 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
arg
->
threadidx
=
block_idx
/
blocksize
;
arg
->
threadidx
=
block_idx
/
blocksize
;
arg
->
blocksize
=
blocksize
;
arg
->
blocksize
=
blocksize
;
#ifdef _WIN32
new
(
&
threads
[
chunks_processed
])
std
::
thread
(
quantize_block
,
arg
);
#else
pthread_create
(
&
threads
[
chunks_processed
],
NULL
,
&
quantize_block
,
(
void
*
)
arg
);
pthread_create
(
&
threads
[
chunks_processed
],
NULL
,
&
quantize_block
,
(
void
*
)
arg
);
#endif
chunks_processed
+=
1
;
chunks_processed
+=
1
;
if
(
chunks_processed
==
valid_chunks
){
break
;
}
if
(
chunks_processed
==
valid_chunks
){
break
;
}
}
}
for
(
int
i
=
0
;
i
<
valid_chunks
;
i
++
)
for
(
int
i
=
0
;
i
<
valid_chunks
;
i
++
)
{
#ifdef _WIN32
threads
[
i
].
join
();
#else
int
err
=
pthread_join
(
threads
[
i
],
NULL
);
int
err
=
pthread_join
(
threads
[
i
],
NULL
);
#endif
}
free
(
threads
);
free
(
threads
);
for
(
int
i
=
0
;
i
<
valid_chunks
;
i
++
)
for
(
int
i
=
0
;
i
<
valid_chunks
;
i
++
)
free
(
args
[
i
]);
free
(
args
[
i
]);
...
...
csrc/kernels.cu
View file @
89876bb0
...
@@ -3821,12 +3821,12 @@ template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N
...
@@ -3821,12 +3821,12 @@ template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N
template
__global__
void
kExtractOutliers
<
COL_TURING
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
template
__global__
void
kExtractOutliers
<
COL_TURING
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
template
__global__
void
kExtractOutliers
<
COL_AMPERE
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
template
__global__
void
kExtractOutliers
<
COL_AMPERE
>(
char
*
A
,
int
*
idx
,
char
*
out
,
int
idx_size
,
int
rowsA
,
int
colsA
,
int
tiledRowsA
,
int
tiledColsA
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
8
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
8
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
__restrict__
const
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
16
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
16
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
__restrict__
const
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
32
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
half
,
32
,
16
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
half
*
B
,
half
*
out
,
float
*
__restrict__
const
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
signed
char
,
8
,
8
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
signed
char
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
signed
char
,
8
,
8
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
signed
char
*
B
,
half
*
out
,
float
*
__restrict__
const
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
signed
char
,
16
,
8
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
signed
char
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
signed
char
,
16
,
8
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
signed
char
*
B
,
half
*
out
,
float
*
__restrict__
const
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
signed
char
,
32
,
8
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
signed
char
*
B
,
half
*
out
,
float
*
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kspmm_coo_very_sparse_naive
<
signed
char
,
32
,
8
>(
int
*
max_count
,
int
*
max_idx
,
int
*
offset_rowidx
,
int
*
rowidx
,
int
*
colidx
,
half
*
values
,
signed
char
*
B
,
half
*
out
,
float
*
__restrict__
const
dequant_stats
,
int
nnz
,
int
rowsA
,
int
rowsB
,
int
colsB
);
template
__global__
void
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
0
,
COL32
>(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
__global__
void
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
0
,
COL32
>(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
__global__
void
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
1
,
COL32
>(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
template
__global__
void
kTransformRowToFormat
<
256
,
8
,
32
,
32
*
8
,
1
,
COL32
>(
char
*
__restrict__
const
A
,
char
*
out
,
int
rows
,
int
cols
,
int
tiledCols
,
int
outRows
,
int
outCols
);
...
...
csrc/ops.cuh
View file @
89876bb0
...
@@ -9,7 +9,6 @@
...
@@ -9,7 +9,6 @@
#include <stdio.h>
#include <stdio.h>
#include <iostream>
#include <iostream>
#include <unistd.h>
#include <assert.h>
#include <assert.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime_api.h>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment