Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
cabcd9b9
Commit
cabcd9b9
authored
Apr 30, 2023
by
Tim Dettmers
Browse files
Halved shared memory 466.
parent
30d03e02
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
43 additions
and
27 deletions
+43
-27
csrc/kernels.cu
csrc/kernels.cu
+43
-27
No files found.
csrc/kernels.cu
View file @
cabcd9b9
...
...
@@ -3053,25 +3053,23 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
//// Allocate shared memory for BlockReduce
//__shared__ typename BlockReduce::TempStorage reduce;
int
col_offset
=
blockIdx
.
x
*
8
;
const
int
warp_id
=
threadIdx
.
x
/
32
;
const
int
half_warp_id
=
threadIdx
.
x
/
16
;
const
int
half_warp_lane
=
threadIdx
.
x
%
16
;
T
local_A
[
64
/
BITS
];
T
local_B
[
64
/
BITS
];
T
local_C
[
8
];
T
local_A
[
1
];
T
local_B
[
8
];
const
int
a_tile_offset
=
32
*
16
+
16
;
const
int
b_tile_offset
=
16
*
8
+
16
;
const
int
c_tile_offset
=
32
*
8
+
24
;
__shared__
T
smem_A
[
WARPS
*
32
*
16
*
2
+
(
16
*
(
WARPS
-
1
))];
__shared__
T
smem_B
[
WARPS
*
16
*
8
*
2
+
(
16
*
(
WARPS
-
1
))];
__shared__
T
smem_A
[
WARPS
*
32
*
16
+
(
16
*
(
WARPS
-
1
))];
__shared__
T
smem_B
[
WARPS
*
16
*
8
+
(
16
*
(
WARPS
-
1
))];
__shared__
T
smem_C
[
WARPS
*
32
*
8
+
(
24
*
(
WARPS
-
1
))];
wmma
::
fragment
<
wmma
::
matrix_a
,
32
,
8
,
16
,
half
,
wmma
::
row_major
>
a_frag
;
wmma
::
fragment
<
wmma
::
matrix_b
,
32
,
8
,
16
,
half
,
wmma
::
col_major
>
b_frag
;
wmma
::
fragment
<
wmma
::
matrix_a
,
32
,
8
,
16
,
half
,
wmma
::
row_major
>
a2_frag
;
wmma
::
fragment
<
wmma
::
matrix_b
,
32
,
8
,
16
,
half
,
wmma
::
col_major
>
b2_frag
;
wmma
::
fragment
<
wmma
::
accumulator
,
32
,
8
,
16
,
half
>
c_frag
;
wmma
::
fill_fragment
(
c_frag
,
0.0
f
);
...
...
@@ -3087,9 +3085,9 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
smem_C
[
i
]
=
T
(
0
);
__syncthreads
();
#pragma unroll 8
for
(
int
k
=
0
;
k
<
8
;
k
++
)
local_C
[
k
]
=
T
(
0
);
//
#pragma unroll 8
//
for(int k = 0; k < 8; k++)
//
local_C[k] = T(0);
//int block_idx = 0;
//for(int base_idx = 0; base_idx < K; base_idx+=blockDim.x)
...
...
@@ -3097,6 +3095,21 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
{
int
idx
=
base_idx
+
threadIdx
.
x
;
for
(
int
k
=
0
;
k
<
2
;
k
++
)
{
if
(
k
==
0
)
{
if
(
idx
<
K
)
{
local_A
[
0
]
=
A
[
idx
];
#pragma unroll 8
for
(
int
col
=
0
;
col
<
8
;
col
++
)
local_B
[
col
]
=
B
[(
col_offset
+
col
)
*
ldb
+
idx
];
}
}
if
(
idx
>=
K
)
{
smem_A
[
threadIdx
.
x
]
=
0.0
f
;
...
...
@@ -3104,20 +3117,23 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
}
else
{
smem_A
[
half_warp_lane
+
(
half_warp_id
*
a_tile_offset
)]
=
A
[
idx
];
if
((
k
==
0
&&
half_warp_id
%
2
==
0
)
||
(
k
==
1
&&
half_warp_id
%
2
==
1
))
{
smem_A
[
half_warp_lane
+
(
warp_id
*
a_tile_offset
)]
=
local_A
[
0
];
#pragma unroll 8
for
(
int
col
=
0
;
col
<
8
;
col
++
)
smem_B
[
half_warp_lane
+
(
half_warp_id
*
b_tile_offset
)
+
(
col
*
16
)]
=
B
[(
col_offset
+
col
)
*
ldb
+
idx
];
smem_B
[
half_warp_lane
+
(
warp_id
*
b_tile_offset
)
+
(
col
*
16
)]
=
local_B
[
col
];
}
}
__syncthreads
();
wmma
::
load_matrix_sync
(
a_frag
,
&
(
smem_A
[
0
]),
16
);
// 111 mu
wmma
::
load_matrix_sync
(
b_frag
,
&
(
smem_B
[
0
]),
16
);
// 35 mu
wmma
::
load_matrix_sync
(
a2_frag
,
&
(
smem_A
[
half_warp_id
*
a_tile_offset
]),
16
);
// 111 mu
wmma
::
load_matrix_sync
(
b2_frag
,
&
(
smem_B
[
half_warp_id
*
b_tile_offset
]),
16
);
// 35 mu
wmma
::
load_matrix_sync
(
a_frag
,
&
(
smem_A
[
warp_id
*
a_tile_offset
]),
16
);
// 111 mu
wmma
::
load_matrix_sync
(
b_frag
,
&
(
smem_B
[
warp_id
*
b_tile_offset
]),
16
);
// 35 mu
wmma
::
mma_sync
(
c_frag
,
a_frag
,
b_frag
,
c_frag
);
wmma
::
mma_sync
(
c_frag
,
a2_frag
,
b2_frag
,
c_frag
);
}
}
// 129 mu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment