OpenDAS / bitsandbytes

Commit 30d03e02 authored Apr 30, 2023 by Tim Dettmers

64 threads, high smem, 434.

parent e01d4e03
Showing 2 changed files with 26 additions and 25 deletions (+26 -25)

csrc/kernels.cu  +24 -24
csrc/ops.cu      +2 -1
csrc/kernels.cu  (view file @ 30d03e02)
@@ -3041,7 +3041,7 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
     }
 }

-#define WARPS 1
+#define WARPS 2
 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
 {
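For orientation (not part of the diff): the hunks below index the shared-memory tiles per 16-thread half-warp, using identifiers declared in the collapsed context of gemm_device. A minimal sketch of those declarations, inferred from how the indices are used in the changed lines and therefore an assumption rather than a quote from the file:

    // Sketch only: assumed declarations from the collapsed lines of gemm_device.
    const int half_warp_id   = threadIdx.x / 16;  // which 16-thread half-warp this thread is in
    const int half_warp_lane = threadIdx.x % 16;  // lane within that half-warp
    const int col_offset     = blockIdx.x * 8;    // each block produces an 8-column slice of out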
@@ -3062,10 +3062,11 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
     const int a_tile_offset = 32*16 + 16;
     const int b_tile_offset = 16*8 + 16;
+    const int c_tile_offset = 32*8 + 24;

-    __shared__ T smem_A[WARPS*32*16*2 + (16*1)];
-    __shared__ T smem_B[WARPS*16*8*2 + (16*1)];
-    __shared__ T smem_C[WARPS*32*8];
+    __shared__ T smem_A[WARPS*32*16*2 + (16*(WARPS-1))];
+    __shared__ T smem_B[WARPS*16*8*2 + (16*(WARPS-1))];
+    __shared__ T smem_C[WARPS*32*8 + (24*(WARPS-1))];

     wmma::fragment<wmma::matrix_a, 32, 8, 16, half, wmma::row_major> a_frag;
     wmma::fragment<wmma::matrix_b, 32, 8, 16, half, wmma::col_major> b_frag;
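A rough footprint check (not part of the commit), assuming T = half (2 bytes) and WARPS = 2 as defined above:

    // Back-of-the-envelope static shared memory per block with WARPS = 2, T = half:
    //   smem_A: 2*32*16*2 + 16*(2-1) = 2064 halves = 4128 bytes
    //   smem_B: 2*16*8*2  + 16*(2-1) =  528 halves = 1056 bytes
    //   smem_C: 2*32*8    + 24*(2-1) =  536 halves = 1072 bytes
    //   total ~ 6256 bytes, roughly double the WARPS = 1 configuration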
@@ -3092,46 +3093,45 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
     //int block_idx = 0;
     //for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
-    for(int base_idx = 0; base_idx < K; base_idx+=32)
+    for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
     {
       int idx = base_idx + threadIdx.x;

-      if(idx >= K)
-      {
-        smem_A[threadIdx.x] = 0.0f;
-        //smem_B[threadIdx.x] = 0.0f;
-      }
-      else
-      {
-        smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = A[idx];
+      if(idx >= K)
+      {
+        smem_A[threadIdx.x] = 0.0f;
+        //smem_B[threadIdx.x] = 0.0f;
+      }
+      else
+      {
+        smem_A[half_warp_lane + (half_warp_id*a_tile_offset)] = A[idx];

-        for(int col = 0; col < 8; col++)
-          smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = B[(col_offset+col)*ldb + idx];
-      }
+        for(int col = 0; col < 8; col++)
+          smem_B[half_warp_lane + (half_warp_id*b_tile_offset) + (col*16)] = B[(col_offset+col)*ldb + idx];
+      }

       __syncthreads();

       wmma::load_matrix_sync(a_frag, &(smem_A[0]), 16); // 111 mu
       wmma::load_matrix_sync(b_frag, &(smem_B[0]), 16); // 35 mu
-      wmma::load_matrix_sync(a2_frag, &(smem_A[a_tile_offset]), 16); // 111 mu
-      wmma::load_matrix_sync(b2_frag, &(smem_B[b_tile_offset]), 16); // 35 mu
+      wmma::load_matrix_sync(a2_frag, &(smem_A[half_warp_id*a_tile_offset]), 16); // 111 mu
+      wmma::load_matrix_sync(b2_frag, &(smem_B[half_warp_id*b_tile_offset]), 16); // 35 mu

       wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
       wmma::mma_sync(c_frag, a2_frag, b2_frag, c_frag);
     }

     // 129 mu
-    wmma::store_matrix_sync(&(smem_C[0]), c_frag, 8, wmma::mem_row_major);
+    wmma::store_matrix_sync(&(smem_C[half_warp_id*c_tile_offset]), c_frag, 8, wmma::mem_row_major);
     __syncthreads();

     //if(threadIdx.x >= 16){ return; }
     //printf("%i %f\n", threadIdx.x, (float)smem_C[threadIdx.x]);

     //if(threadIdx.x < 32)
-    //if(warp_lane < 8 && warp_id > 0)
-    //  //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)];
-    //  atomicAdd(&(smem_C[warp_lane]), smem_C[warp_lane + (warp_id*32*8)]);
-    //__syncthreads();
+    if(half_warp_lane < 8 && half_warp_id > 0)
+      //local_C[warp_lane] = smem_C[warp_lane + (warp_id*32*8)];
+      atomicAdd(&(smem_C[half_warp_lane]), smem_C[half_warp_lane + (half_warp_id*c_tile_offset)]);
+    __syncthreads();

     //local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
     //if(threadIdx.x == 0)
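The effect of the re-enabled reduction at the end of the hunk, restated for clarity (an illustration, not code from the commit): each half-warp now stores its accumulator tile at half_warp_id*c_tile_offset, and lanes 0-7 of every half-warp other than the first fold the first eight elements of their tile into the first tile,

    // Equivalent scalar view of the atomicAdd reduction above (illustrative only):
    // for(int hw = 1; hw < blockDim.x/16; hw++)
    //     for(int lane = 0; lane < 8; lane++)
    //         smem_C[lane] += smem_C[lane + hw*c_tile_offset];

presumably leaving the block's eight output values accumulated in smem_C[0..7].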
csrc/ops.cu  (view file @ 30d03e02)
@@ -693,7 +693,8 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
     //gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
     if(bits == 16)
       //gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
-      gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+      //gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
+      gemm_device<T, 16, 64><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
 }

 template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char * B, float * absmax, T * out, int lda, int ldb, int ldc, int blocksize)
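For context, a hypothetical host-side call that would take the new path (the gemm_host signature past "T * out" is truncated in the hunk header, so the trailing lda/ldb/ldc/bits parameters are assumed from how the body uses them):

    // Hypothetical usage sketch, not from the diff: dA, dB, dOut are device buffers of half.
    // gemm_host<half>(m, n, k, dA, dB, dOut, lda, ldb, ldc, 16);  // bits == 16 selects the new gemm_device<T, 16, 64> path

Note that the new call instantiates gemm_device<T, 16, 64> while the launch configuration shown in the diff still passes 32 threads per block.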