OpenDAS / bitsandbytes / Commits / f3e97ccb

Commit f3e97ccb authored Apr 28, 2023 by Tim Dettmers

New implementation for batch size 1.

parent f6df4aef

Showing 6 changed files with 196 additions and 103 deletions:

  csrc/kernels.cu            +179  -86
  csrc/kernels.cuh             +1   -1
  csrc/ops.cu                  +5   -5
  csrc/ops.cuh                 +1   -1
  csrc/pythonInterface.c       +4   -4
  tests/test_functional.py     +6   -6
csrc/kernels.cu

...
@@ -2947,117 +2947,212 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
//// 9. write outputs to matmul output matrix
//}
#define ROWS 2
-template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
+template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{
  // 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp
  // 1. Load dataB into register
  // 2. Dequantize B
  // 3. Fetch data from A and multiply

-  typedef cub::BlockLoad<T, THREADS, ITEMS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadA;
-  //__shared__ typename LoadA::TempStorage loada;
-  typedef cub::BlockLoad<T, THREADS, ITEMS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadB;
-  //__shared__ typename LoadB::TempStorage loadb;
-  typedef cub::BlockReduce<T, THREADS> BlockReduce;
-  // Allocate shared memory for BlockReduce
-  //__shared__ typename BlockReduce::TempStorage reduce;
-
-  __shared__ union {
-    typename BlockReduce::TempStorage reduce;
-    typename LoadB::TempStorage loadb;
-    typename LoadA::TempStorage loada;
-  } temp_storage;
-
-  T dataA[ITEMS];
-  T local_B[ITEMS];
-  T local_accC[ROWS];
-  int valid_items = 0;
-  const int col_offset = blockIdx.x * 8;
-
-  __shared__ T tileA[ROWS*THREADS*ITEMS];
-  __shared__ T accumulatorC[ROWS*8];
-
-  //#pragma unroll 8
-  //for(int i = 0; i < 8; i++)
-  //  tileA[threadIdx.x + (i*256)] = 0.0f;
-  //__syncthreads();
-  if(threadIdx.x < 64)
-    accumulatorC[threadIdx.x] = 0.0f;
+  typedef cub::BlockReduce<T, THREADS> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage reduce;
+
+  int col_offset = blockIdx.x * 8;
+
+  T local_A[8];
+  T local_B[8];
+  T local_C[8];
+
+  __shared__ T smem_C[8];
+
+  if(threadIdx.x < 8)
+    smem_C[threadIdx.x] = T(0);
  __syncthreads();

+  #pragma unroll 8
+  for(int k = 0; k < 8; k++)
+    local_C[k] = T(0);
+
-  for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS)
-  {
-    valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
-    int baserow = 0;
-    for(int row = baserow; row < (baserow+ROWS) && row < N; row++)
-    {
-      LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f);
-
-      #pragma unroll ITEMS
-      for(int k = 0; k < ITEMS; k++)
-        tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k];
-
-      __syncthreads();
-    }
-    baserow += ROWS;
-
-    // load 16 columns from B at a time. B is transposed, so its like loading rows
-    // each warp loads one row
-    // each thread loads 128 byte
-
-    // col: inner_idx + warp_lane
-    // row: ldb*(offset + warp_id)
-    for(int col = 0; col < 8 && (col_offset + col) < M; col++)
-    {
-      int colB = col_offset + col;
-
-      for(int k = 0; k < ROWS; k++)
-        local_accC[k] = 0.0f;
-
-      int base_idxB = ldb*colB;
-      valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
-      LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f);
-      __syncthreads();
-
-      for(int row = 0; row < ROWS && row < N; row++)
-      {
-        #pragma unroll ITEMS
-        for(int k = 0; k < ITEMS; k++)
-        {
-          int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k);
-          local_accC[row] += tileA[idxA]*local_B[k];
-        }
-
-        local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
-        if(threadIdx.x == 0)
-          atomicAdd(&accumulatorC[row*8 + col], local_accC[row]);
-      }
-    }
-  }
+  for(int idx = threadIdx.x*8; idx < K; idx += blockDim.x*8)
+  {
+    if(idx + 8 <= K)
+      reinterpret_cast<float4(&)[8]>(local_A)[0] = reinterpret_cast<float4*>(A)[idx/8];
+    else
+    {
+      for(int k = 0; k < 8; k++)
+      {
+        if(idx + k < K)
+          local_A[k] = A[idx+k];
+        else
+          local_A[k] = 0.0f;
+      }
+    }
+
+    for(int col = 0; col < 8; col++)
+    {
+      int offset_B = (col_offset+col)*ldb;
+      if(idx + 8 <= K)
+        reinterpret_cast<float4(&)[8]>(local_B)[0] = reinterpret_cast<float4*>(B)[(offset_B+idx)/8];
+      else
+      {
+        for(int k = 0; k < 8; k++)
+        {
+          if(idx + k < K)
+            local_B[k] = B[(offset_B+idx)+k];
+          else
+            local_B[k] = 0.0f;
+        }
+      }
+
+      #pragma unroll 8
+      for(int k = 0; k < 8; k++)
+      {
+        local_C[col] += local_A[k]*local_B[k];
+        //if((float)local_A[k] != 0.0 && (float)local_B[k] != 0.0)
+        //  printf("%i %i %f %f %f\n", k, threadIdx.x, (float)local_A[k], (float)local_B[k], (float)local_C[col]);
+      }
+    }
+  }

-  for(int row = 0; row < ROWS && row < N; row++)
-  {
-    int out_idx = ldc*row + col_offset;
-
-    //if(threadIdx.x < 8)
-    //  if(accumulatorC[row*8 + threadIdx.x] != 0.0)
-    //    printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x);
-
-    if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M)
-    {
-      //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx);
-      out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x];
-    }
-  }
+  #pragma unroll 8
+  for(int k = 0; k < 8; k++)
+  {
+    local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum());
+    __syncthreads();
+  }
+
+  if(threadIdx.x == 0)
+    #pragma unroll 8
+    for(int k = 0; k < 8; k++)
+      smem_C[k] = local_C[k];
+  else if(threadIdx.x >= 32)
+    // early return for unused warps
+    return;
+
+  __syncwarp();
+
+  //for(int k = 0; k < 8; k++)
+  //  if((float)local_C[k] != 0.0f)
+  //    printf("%i %f\n", threadIdx.x, (float)local_C[k]);
+
+  if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
+    out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
}
//#define ROWS 2
//template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
//{
//// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp
//// 1. Load dataB into register
//// 2. Dequantize B
//// 3. Fetch data from A and multiply
//
// typedef cub::BlockLoad<T, THREADS , ITEMS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadA;
// //__shared__ typename LoadA::TempStorage loada;
// typedef cub::BlockLoad<T, THREADS , ITEMS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadB;
// //__shared__ typename LoadB::TempStorage loadb;
// typedef cub::BlockReduce<T, THREADS> BlockReduce;
// // Allocate shared memory for BlockReduce
// //__shared__ typename BlockReduce::TempStorage reduce;
//
// __shared__ union {
// typename BlockReduce::TempStorage reduce;
// typename LoadB::TempStorage loadb;
// typename LoadA::TempStorage loada;
// } temp_storage;
//
//
// T dataA[ITEMS];
// T local_B[ITEMS];
// T local_accC[ROWS];
// int valid_items = 0;
// const int col_offset = blockIdx.x * 8;
//
// __shared__ T tileA[ROWS*THREADS*ITEMS];
// __shared__ T accumulatorC[ROWS*8];
//
// //#pragma unroll 8
// //for(int i = 0; i < 8; i++)
// // tileA[threadIdx.x + (i*256)] = 0.0f;
// //__syncthreads();
// if(threadIdx.x < 64)
// accumulatorC[threadIdx.x] = 0.0f;
// __syncthreads();
//
//
// for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS)
// {
// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
// int baserow = 0;
// for(int row = baserow; row < (baserow+ROWS) && row < N; row++)
// {
// LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f);
//
// #pragma unroll ITEMS
// for(int k = 0; k < ITEMS; k++)
// tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k];
//
// __syncthreads();
// }
// baserow += ROWS;
//
// // load 16 columns from B at a time. B is transposed, so its like loading rows
// // each warp loads one row
// // each thread loads 128 byte
//
// // col: inner_idx + warp_lane
// // row: ldb*(offset + warp_id)
// for(int col = 0; col < 8 && (col_offset + col) < M; col++)
// {
// int colB = col_offset + col;
//
// for(int k = 0; k < ROWS; k++)
// local_accC[k] = 0.0f;
//
// int base_idxB = ldb*colB;
// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx;
// LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f);
// __syncthreads();
//
// for(int row = 0; row < ROWS && row < N; row++)
// {
// #pragma unroll ITEMS
// for(int k = 0; k < ITEMS; k++)
// {
// int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k);
// local_accC[row] += tileA[idxA]*local_B[k];
// }
//
// local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum());
// if(threadIdx.x == 0)
// atomicAdd(&accumulatorC[row*8 + col], local_accC[row]);
// }
// }
// }
//
// for(int row = 0; row < ROWS && row < N; row++)
// {
// int out_idx = ldc*row + col_offset;
//
// //if(threadIdx.x < 8)
// // if(accumulatorC[row*8 + threadIdx.x] != 0.0)
// // printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x);
//
// if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M)
// {
// //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx);
// out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x];
// }
// }
//
//
//
//}
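The new kernel drops the CUB block-load and shared-tile machinery of the old implementation kept commented out above: for batch size 1, each thread block owns eight output columns, each thread strides over K in chunks of eight elements (a single 128-bit float4 load when the chunk lies fully inside K), and the per-thread partial dot products are combined with cub::BlockReduce. Below is a minimal standalone sketch of that pattern without the vectorized loads; the name gemv_batch1_sketch is hypothetical and not part of the commit, fp32-only, with B assumed row-major with leading dimension K.

#include <cub/cub.cuh>

// Sketch of the batch-size-1 pattern above: one block produces the eight
// outputs out[col_offset .. col_offset+7]; every thread accumulates partial
// dot products over a strided slice of K, then the block reduces them.
template <int THREADS>
__global__ void gemv_batch1_sketch(int M, int K, const float *A, const float *B, float *out)
{
  typedef cub::BlockReduce<float, THREADS> BlockReduce;
  __shared__ typename BlockReduce::TempStorage reduce;
  __shared__ float smem_C[8];

  const int col_offset = blockIdx.x * 8;
  float local_C[8] = {0.0f};

  // Each thread owns a strided slice of the K dimension.
  for (int idx = threadIdx.x * 8; idx < K; idx += blockDim.x * 8)
    for (int k = 0; k < 8 && idx + k < K; k++)
      for (int col = 0; col < 8 && col_offset + col < M; col++)
        local_C[col] += A[idx + k] * B[(col_offset + col) * K + (idx + k)];

  // Block-wide sums; the reduced value is only valid in thread 0.
  for (int k = 0; k < 8; k++)
  {
    float sum = BlockReduce(reduce).Reduce(local_C[k], cub::Sum());
    if (threadIdx.x == 0)
      smem_C[k] = sum;
    __syncthreads();  // temp storage is reused across iterations
  }

  if (threadIdx.x < 8 && col_offset + threadIdx.x < M)
    out[col_offset + threadIdx.x] = smem_C[threadIdx.x];
}

Launched as gemv_batch1_sketch<128><<<(M + 7) / 8, 128>>>(M, K, A, B, out), this mirrors the (m + 7) / 8 grid and 128-thread block that gemm_host computes below.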
__device__ void compute(float* global_out, float const* shared_in)
{
...
@@ -3122,10 +3217,8 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
// TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
// TC * out, CStride dC, CBlockLayout , CThreadLayout tC,
// half alpha, half beta);
-template __global__ void gemm_device<float, 4, 256>(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc);
-template __global__ void gemm_device<half, 4, 256>(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc);
-template __global__ void gemm_device<float, 8, 256>(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc);
-template __global__ void gemm_device<half, 8, 256>(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
+template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
//template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
...
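The explicit instantiations move from the <4, 256> and <8, 256> shapes to a single <16, 128> shape per type. In the new kernel body the ITEMS parameter no longer drives the inner loops (each thread always processes chunks of eight elements), but THREADS still sizes the BlockReduce. A quick sanity check of the per-iteration K coverage, illustrative only and assuming the fp16 path, where one 128-bit load is exactly eight elements:

// Per-iteration K coverage for gemm_device<half, 16, 128> (sketch):
constexpr int threads = 128;         // blockDim.x
constexpr int elems_per_thread = 8;  // one float4 (16 bytes) of half values
constexpr int k_per_iter = threads * elems_per_thread;
static_assert(k_per_iter == 1024, "a block sweeps 1024 elements of K per idx step");
// With K = 4096, as in the updated test below, the idx loop runs 4 times.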
csrc/kernels.cuh

...
@@ -138,6 +138,6 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
template <size_t stages_count /* Pipeline with stages_count stages */> __global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);

-template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc);
+template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
#endif
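The header-side change mirrors the kernel signature: the qualifier moves from the pointee to the pointer and gains a no-aliasing hint. The distinction matters because the kernel now performs reinterpret_cast<float4*>(A) for vectorized loads, which the old pointee-const type would have required a const_cast for. Read side by side (fp32 shown for concreteness):

float const* a_old;              // pointer to const float: the data is read-only
float* __restrict__ const a_new; // const pointer to float: the pointer itself is
                                 // fixed, and __restrict__ promises no other
                                 // pointer aliases the memory, helping the
                                 // compiler vectorize and cache the loads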
csrc/ops.cu

...
@@ -675,10 +675,10 @@ void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
-template <typename T> void gemm_host(int m, int n, int k, T const* A, T* B, T * out, int lda, int ldb, int ldc)
+template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc)
{
-  dim3 dimBlock(256);
+  dim3 dimBlock(128);
  int num_blocks = (m+7)/8;

  cout << num_blocks << endl;
...
@@ -689,7 +689,7 @@ template <typename T> void gemm_host(int m, int n, int k, T const* A, T* B, T
  cout << m << endl;
  cout << n << endl;
  cout << k << endl;
-  gemm_device<T, 8, 256><<<num_blocks, dimBlock, 0, 0>>>(m, n, k, A,
+  gemm_device<T, 16, 128><<<num_blocks, dimBlock, 0, 0>>>(m, n, k, A,
...
@@ -701,8 +701,8 @@ template <typename T> void gemm_host(int m, int n, int k, T const* A, T* B, T
// TEMPLATE DEFINITIONS
//==============================================================
-template void gemm_host<float>(int m, int n, int k, float const* A, float* B, float * out, int lda, int ldb, int ldc);
-template void gemm_host<half>(int m, int n, int k, half const* A, half* B, half * out, int lda, int ldb, int ldc);
+template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc);
+template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc);

template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
...
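gemm_host keeps one block per eight output values but halves the block size to match the new <T, 16, 128> instantiation. With m = 4 * 4096, as in the updated test, the launch arithmetic works out as below (a sketch with assumed values, not code from the commit):

int m = 4 * 4096;              // output length, as in test_cutlass3_gemm
dim3 dimBlock(128);            // was dim3 dimBlock(256)
int num_blocks = (m + 7) / 8;  // one block per 8 outputs: 2048 blocks here
// gemm_device<T, 16, 128><<<num_blocks, dimBlock, 0, 0>>>(m, n, k, A, B, out, lda, ldb, ldc);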
csrc/ops.cuh

...
@@ -190,7 +190,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
void matmul4bite(half * A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);

-template <typename T> void gemm_host(int m, int n, int k, T const* A, T* B, T * out, int lda, int ldb, int ldc);
+template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc);

void pipeline_test(float *A, float *B, size_t n, size_t batch_size);
...

csrc/pythonInterface.c

...
@@ -20,9 +20,9 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat
void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }

-void gemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc)
+void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc); }

-void gemm_host_fp16(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc)
+void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
{ gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc); }
...
@@ -313,10 +313,10 @@ extern "C"
void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); }

-void cgemm_host_fp32(int M, int N, int K, float const* A, float* B, float * out, int lda, int ldb, int ldc)
+void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }

-void cgemm_host_fp16(int M, int N, int K, half const* A, half* B, half * out, int lda, int ldb, int ldc)
+void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
{ gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); }
#endif
...
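These extern "C" wrappers are the symbols the Python side binds; dropping const on A keeps the pointer type consistent through the whole call chain down to the kernel's reinterpret_cast. A hypothetical external caller follows; the helper name and the leading-dimension values are illustrative assumptions, not part of the commit, and all buffers are assumed to be CUDA device pointers:

extern "C" void cgemm_host_fp32(int M, int N, int K, float *A, float *B,
                                float *out, int lda, int ldb, int ldc);

// Batch-size-1 call: out (length M) = B (M x K, row-major) * A (length K).
void run_batch1_gemm_fp32(float *d_A, float *d_B, float *d_out, int M, int K)
{
  cgemm_host_fp32(M, /*N=*/1, K, d_A, d_B, d_out, /*lda=*/K, /*ldb=*/K, /*ldc=*/M);
}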
tests/test_functional.py

...
@@ -2355,11 +2355,11 @@ def test_normal_map_tree():
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
def test_cutlass3_gemm(dtype):
-    for i in range(2):
-        A = torch.rand(2, 4092, dtype=dtype, device='cuda')
-        B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
-        #A = torch.rand(2, 4, dtype=dtype, device='cuda')
-        #B = torch.rand(4, 4, dtype=dtype, device='cuda')
+    for i in range(1):
+        #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
+        #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
+        A = torch.rand(1, 4096, dtype=dtype, device='cuda')
+        B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')

        #print('')
        #print(A)
...
@@ -2371,7 +2371,7 @@ def test_cutlass3_gemm(dtype):
        #print(C1)
        #print(C2)

-        torch.testing.assert_close(C1, C2)
+        torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.005)
def test_pipeline_func():
...