Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
3d4a2ead
Commit
3d4a2ead
authored
May 01, 2023
by
Tim Dettmers
Browse files
16x16 240.
parent
7cc8ff47
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
27 deletions
+27
-27
csrc/kernels.cu
csrc/kernels.cu
+26
-26
csrc/ops.cu
csrc/ops.cu
+1
-1
No files found.
csrc/kernels.cu
View file @
3d4a2ead
...
...
@@ -3052,37 +3052,37 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
//typedef cub::BlockReduce<T, THREADS> BlockReduce;
//// Allocate shared memory for BlockReduce
//__shared__ typename BlockReduce::TempStorage reduce;
int
col_offset
=
blockIdx
.
x
*
8
;
int
col_offset
=
blockIdx
.
x
*
16
;
const
int
warp_id
=
threadIdx
.
x
/
32
;
const
int
half_warp_id
=
threadIdx
.
x
/
16
;
const
int
half_warp_lane
=
threadIdx
.
x
%
16
;
const
int
batch_size_warps
=
(
WARPS
-
1
)
*
2
;
T
local_A
[
1
];
T
local_B
[
8
];
T
local_B
[
16
];
const
int
a_tile_offset
=
(
32
*
16
+
16
);
const
int
b_tile_offset
=
(
16
*
8
+
16
);
const
int
c_tile_offset
=
32
*
8
+
24
;
const
int
a_tile_offset
=
(
16
*
16
+
16
);
const
int
b_tile_offset
=
(
16
*
16
+
16
);
const
int
c_tile_offset
=
16
*
16
+
24
;
__shared__
T
smem_A
[
2
*
batch_size_warps
*
32
*
16
+
(
2
*
16
*
(
batch_size_warps
-
1
))];
__shared__
T
smem_B
[
2
*
batch_size_warps
*
16
*
8
+
(
2
*
16
*
(
batch_size_warps
-
1
))];
__shared__
T
smem_C
[
32
*
8
];
__shared__
T
smem_A
[
2
*
batch_size_warps
*
16
*
16
+
(
2
*
16
*
(
batch_size_warps
-
1
))];
__shared__
T
smem_B
[
2
*
batch_size_warps
*
16
*
16
+
(
2
*
16
*
(
batch_size_warps
-
1
))];
__shared__
T
smem_C
[
16
*
16
];
wmma
::
fragment
<
wmma
::
matrix_a
,
32
,
8
,
16
,
half
,
wmma
::
row_major
>
a_frag
;
wmma
::
fragment
<
wmma
::
matrix_b
,
32
,
8
,
16
,
half
,
wmma
::
col_major
>
b_frag
;
wmma
::
fragment
<
wmma
::
accumulator
,
32
,
8
,
16
,
half
>
c_frag
;
wmma
::
fragment
<
wmma
::
matrix_a
,
16
,
16
,
16
,
half
,
wmma
::
row_major
>
a_frag
;
wmma
::
fragment
<
wmma
::
matrix_b
,
16
,
16
,
16
,
half
,
wmma
::
col_major
>
b_frag
;
wmma
::
fragment
<
wmma
::
accumulator
,
16
,
16
,
16
,
half
>
c_frag
;
wmma
::
fill_fragment
(
c_frag
,
0.0
f
);
for
(
int
i
=
threadIdx
.
x
;
i
<
32
*
16
*
WARPS
;
i
+=
blockDim
.
x
)
smem_A
[
i
]
=
T
(
0
);
//
for(int i = threadIdx.x; i <
16
*16*WARPS; i+=blockDim.x)
//
smem_A[i] = T(0);
for
(
int
i
=
threadIdx
.
x
;
i
<
32
*
8
*
WARPS
;
i
+=
blockDim
.
x
)
smem_B
[
i
]
=
T
(
0
);
//
for(int i = threadIdx.x; i <
16*16
*WARPS; i+=blockDim.x)
//
smem_B[i] = T(0);
for
(
int
i
=
threadIdx
.
x
;
i
<
32
*
8
*
WARPS
;
i
+=
blockDim
.
x
)
for
(
int
i
=
threadIdx
.
x
;
i
<
16
*
16
;
i
+=
blockDim
.
x
)
smem_C
[
i
]
=
T
(
0
);
__syncthreads
();
...
...
@@ -3099,14 +3099,14 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
{
local_A
[
0
]
=
A
[
idx
];
#pragma unroll
8
for
(
int
col
=
0
;
col
<
8
;
col
++
)
#pragma unroll
16
for
(
int
col
=
0
;
col
<
16
;
col
++
)
local_B
[
col
]
=
B
[(
col_offset
+
col
)
*
ldb
+
idx
];
smem_A
[
half_warp_lane
+
(
half_warp_id
*
a_tile_offset
)]
=
local_A
[
0
];
#pragma unroll
8
for
(
int
col
=
0
;
col
<
8
;
col
++
)
#pragma unroll
16
for
(
int
col
=
0
;
col
<
16
;
col
++
)
smem_B
[
half_warp_lane
+
(
half_warp_id
*
b_tile_offset
)
+
(
col
*
16
)]
=
local_B
[
col
];
}
ticktock
=
ticktock
==
0
?
1
:
0
;
...
...
@@ -3120,14 +3120,14 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
{
local_A
[
0
]
=
A
[
idx
];
#pragma unroll
8
for
(
int
col
=
0
;
col
<
8
;
col
++
)
#pragma unroll
16
for
(
int
col
=
0
;
col
<
16
;
col
++
)
local_B
[
col
]
=
B
[(
col_offset
+
col
)
*
ldb
+
idx
];
smem_A
[
half_warp_lane
+
(((
batch_size_warps
*
ticktock
)
+
half_warp_id
)
*
a_tile_offset
)]
=
local_A
[
0
];
#pragma unroll
8
for
(
int
col
=
0
;
col
<
8
;
col
++
)
#pragma unroll
16
for
(
int
col
=
0
;
col
<
16
;
col
++
)
smem_B
[
half_warp_lane
+
(((
batch_size_warps
*
ticktock
)
+
half_warp_id
)
*
b_tile_offset
)
+
(
col
*
16
)]
=
local_B
[
col
];
}
ticktock
=
ticktock
==
0
?
1
:
0
;
...
...
@@ -3143,7 +3143,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
// 129 mu
if
(
warp_id
==
(
WARPS
-
1
))
wmma
::
store_matrix_sync
(
&
(
smem_C
[
0
]),
c_frag
,
8
,
wmma
::
mem_row_major
);
wmma
::
store_matrix_sync
(
&
(
smem_C
[
0
]),
c_frag
,
16
,
wmma
::
mem_row_major
);
__syncthreads
();
//if(threadIdx.x >= 16){ return; }
...
...
@@ -3185,7 +3185,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
//if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
//out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
if
(
threadIdx
.
x
<
8
&&
col_offset
+
threadIdx
.
x
<
M
)
if
(
threadIdx
.
x
<
16
&&
col_offset
+
threadIdx
.
x
<
M
)
out
[
col_offset
+
threadIdx
.
x
]
=
smem_C
[
threadIdx
.
x
];
}
...
...
csrc/ops.cu
View file @
3d4a2ead
...
...
@@ -678,7 +678,7 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
template
<
typename
T
>
void
gemm_host
(
int
m
,
int
n
,
int
k
,
T
*
A
,
T
*
B
,
T
*
out
,
int
lda
,
int
ldb
,
int
ldc
,
int
bits
)
{
int
num_blocks
=
(
m
+
7
)
/
8
;
int
num_blocks
=
(
m
+
15
)
/
16
;
cout
<<
num_blocks
<<
endl
;
cout
<<
lda
<<
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment