Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
superbenchmark
Commits
0591da5f
Unverified
Commit
0591da5f
authored
Jan 03, 2023
by
Yifan Xiong
Committed by
GitHub
Jan 03, 2023
Browse files
Benchmarks - Add cuBLASLt FP16 and FP8 GEMM micro-benchmark (#451)
Add micro-benchmark for cublaslt fp8 gemm.
parent
678b1251
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
363 additions
and
0 deletions
+363
-0
superbench/benchmarks/micro_benchmarks/cublaslt_fp8_gemm/CMakeLists.txt
...chmarks/micro_benchmarks/cublaslt_fp8_gemm/CMakeLists.txt
+22
-0
superbench/benchmarks/micro_benchmarks/cublaslt_fp8_gemm/cublaslt_fp8_gemm.cu
...s/micro_benchmarks/cublaslt_fp8_gemm/cublaslt_fp8_gemm.cu
+153
-0
superbench/benchmarks/micro_benchmarks/cublaslt_fp8_gemm/cublaslt_utils.cc
...arks/micro_benchmarks/cublaslt_fp8_gemm/cublaslt_utils.cc
+122
-0
superbench/benchmarks/micro_benchmarks/cublaslt_fp8_gemm/cublaslt_utils.h
...marks/micro_benchmarks/cublaslt_fp8_gemm/cublaslt_utils.h
+63
-0
superbench/benchmarks/micro_benchmarks/cuda_common.cmake
superbench/benchmarks/micro_benchmarks/cuda_common.cmake
+3
-0
No files found.
superbench/benchmarks/micro_benchmarks/cublaslt_fp8_gemm/CMakeLists.txt
0 → 100644
View file @
0591da5f
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
cmake_minimum_required(VERSION 3.18)

project(cublaslt_fp8_gemm LANGUAGES CXX)

find_package(CUDAToolkit QUIET)

# Nothing is built or installed unless a CUDA toolkit of version 11.8 or newer was found.
if(CUDAToolkit_FOUND AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL 11.8)
    set(CMAKE_CUDA_STANDARD 17)
    include(../cuda_common.cmake)

    # Shared helper library wrapping the cuBLASLt calls.
    add_library(cublaslt_utils SHARED cublaslt_utils.cc)
    target_link_libraries(cublaslt_utils CUDA::cublas CUDA::cublasLt)
    set_target_properties(cublaslt_utils PROPERTIES LINK_FLAGS_RELEASE -s)
    install(TARGETS cublaslt_utils LIBRARY DESTINATION lib)

    # Benchmark executable, linked against the helper library above.
    add_executable(cublaslt_fp8_gemm cublaslt_fp8_gemm.cu)
    target_link_libraries(cublaslt_fp8_gemm cublaslt_utils)
    set_target_properties(cublaslt_fp8_gemm PROPERTIES CUDA_ARCHITECTURES "80;86;90")
    install(TARGETS cublaslt_fp8_gemm RUNTIME DESTINATION bin)
endif()
superbench/benchmarks/micro_benchmarks/cublaslt_fp8_gemm/cublaslt_fp8_gemm.cu
0 → 100644
View file @
0591da5f
// Copyright(c) Microsoft Corporation.
// Licensed under the MIT License.
#include <algorithm>
#include <getopt.h>
#include <memory>
#include <stdexcept>
#include <stdio.h>
#include <string>
#include <type_traits>

#include <cuda_fp16.h>
#include <cuda_fp8.h>

#include "cublaslt_utils.h"
// Short aliases for the CUDA element types used as template arguments in this file.
using fp16 = half; // nv_bfloat16 (NOTE(review): presumably an alternative 16-bit type -- confirm)
// 8-bit floating point, E4M3 encoding.
using fp8e4m3 = __nv_fp8_e4m3;
// 8-bit floating point, E5M2 encoding.
using fp8e5m2 = __nv_fp8_e5m2;
// Command-line configuration of the benchmark, pre-filled with its defaults.
struct Args {
    // GEMM problem dimensions.
    int m{16};
    int n{16};
    int k{16};
    // Batch count; values <= 0 are treated as a single, non-batched GEMM by the code that consumes it.
    int batch{0};
    // Number of untimed warm-up GEMM calls.
    int warmup{20};
    // Number of timed GEMM calls.
    int iter{50};
    // Input element type name: "fp16", "fp8e4m3" or "fp8e5m2".
    std::string in_type{"fp8e4m3"};
};
/**
 * Parse the command line into `args`.
 *
 * Recognised options (all take a value):
 *   -m, -n, -k          GEMM dimensions
 *   -b / --batch        batch count
 *   -w / --warmup       number of warm-up iterations
 *   -i / --iter         number of timed iterations
 *   -t / --in_type      input type name
 *
 * Fields whose option is absent keep the value they had on entry.
 * Numeric values are converted with std::stoi, which throws
 * std::invalid_argument / std::out_of_range on malformed input.
 *
 * @param argc  argument count as passed to main().
 * @param argv  argument vector as passed to main().
 * @param args  destination structure; must not be null.
 */
void process_args(int argc, char **argv, Args *args) {
    const char *const short_opts = "m:n:k:b:w:i:t:";
    const option long_opts[] = {
        {"batch", required_argument, nullptr, 'b'},
        {"warmup", required_argument, nullptr, 'w'},
        {"iter", required_argument, nullptr, 'i'},
        {"in_type", required_argument, nullptr, 't'},
        // getopt_long() scans the table until it finds an all-zero element; without this
        // terminator it reads past the end of the array when a long option does not match.
        {nullptr, 0, nullptr, 0},
    };

    int opt = 0;
    while ((opt = getopt_long(argc, argv, short_opts, long_opts, nullptr)) != -1) {
        switch (opt) {
        case 'm':
            args->m = std::stoi(optarg);
            break;
        case 'n':
            args->n = std::stoi(optarg);
            break;
        case 'k':
            args->k = std::stoi(optarg);
            break;
        case 'b':
            args->batch = std::stoi(optarg);
            break;
        case 'w':
            args->warmup = std::stoi(optarg);
            break;
        case 'i':
            args->iter = std::stoi(optarg);
            break;
        case 't':
            args->in_type = std::string(optarg);
            break;
        default:
            // Unknown options ('?') are ignored, as before; getopt_long has already reported them.
            break;
        }
    }
}
/**
 * Fill `matrix[0..N)` with `val` converted to the element type `T`.
 *
 * Works for any 1-D launch configuration: each thread starts at its global
 * index and advances by the total number of launched threads until it passes N.
 *
 * @param matrix  device buffer holding at least N elements of T.
 * @param val     fill value, converted with T(val).
 * @param N       number of elements to write.
 */
template <typename T> __global__ void init_matrix(T *matrix, const fp16 val, const size_t N) {
    // Widen before multiplying: blockIdx.x * blockDim.x is otherwise a 32-bit product that can wrap.
    const size_t first = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
    // The conversion does not depend on the index, so do it once.
    const T fill = T(val);
    for (size_t i = first; i < N; i += stride) {
        matrix[i] = fill;
    }
}
/**
 * Map the element type `T` to its cudaDataType_t enumerator.
 *
 * @tparam T  one of fp16, fp8e4m3, fp8e5m2.
 * @return    CUDA_R_16F, CUDA_R_8F_E4M3 or CUDA_R_8F_E5M2 respectively.
 * @throws std::invalid_argument for any other type.
 */
template <typename T> cudaDataType_t get_datatype() {
    constexpr bool is_half = std::is_same<T, fp16>::value;
    constexpr bool is_e4m3 = std::is_same<T, fp8e4m3>::value;
    constexpr bool is_e5m2 = std::is_same<T, fp8e5m2>::value;

    if (is_half) {
        return CUDA_R_16F;
    } else if (is_e4m3) {
        return CUDA_R_8F_E4M3;
    } else if (is_e5m2) {
        return CUDA_R_8F_E5M2;
    }
    throw std::invalid_argument("Unknown type");
}
/**
 * Time a cuBLASLt GEMM with A transposed and B not transposed (CUBLAS_OP_T / CUBLAS_OP_N).
 *
 * Allocates and fills A (all 1) and B (all 2) on the device, runs `warmup` untimed
 * calls followed by `iter` timed calls on the default stream, and frees everything.
 *
 * @tparam Ta, Tb  element types of A and B.
 * @tparam Tout    element type of the output matrix.
 * @param m, n, k  GEMM dimensions.
 * @param batch    batch count; values <= 0 allocate buffers for a single GEMM.
 * @param warmup   number of untimed calls.
 * @param iter     number of timed calls.
 * @return         average time per call in microseconds (event time in ms * 1e3 / iter).
 * @throws std::runtime_error if a CUDA runtime call or kernel launch fails.
 *         NOTE: on failure the exception propagates without releasing the device
 *         buffers allocated so far.
 */
template <typename Ta, typename Tb, typename Tout>
float timing_matmul_tn(int m, int n, int k, int batch, int warmup, int iter) {
    // Convert a failing CUDA runtime status into an exception instead of silently continuing.
    auto check_cuda = [](cudaError_t status, const char *what) {
        if (status != cudaSuccess) {
            throw std::runtime_error(std::string(what) + " failed: " + cudaGetErrorString(status));
        }
    };

    // Element counts are computed in size_t; m * k * batch overflows int for large problems.
    const size_t batch_count = static_cast<size_t>(std::max(batch, 1));
    const size_t a_elems = static_cast<size_t>(m) * k * batch_count;
    const size_t b_elems = static_cast<size_t>(k) * n * batch_count;
    const size_t out_elems = static_cast<size_t>(m) * n * batch_count;

    // init matrix
    Ta *matrix_a = nullptr;
    Tb *matrix_b = nullptr;
    Tout *matrix_out = nullptr;
    check_cuda(cudaMalloc(&matrix_a, a_elems * sizeof(Ta)), "cudaMalloc(matrix_a)");
    check_cuda(cudaMalloc(&matrix_b, b_elems * sizeof(Tb)), "cudaMalloc(matrix_b)");
    check_cuda(cudaMalloc(&matrix_out, out_elems * sizeof(Tout)), "cudaMalloc(matrix_out)");

    init_matrix<Ta><<<216, 1024>>>(matrix_a, static_cast<fp16>(1.f), a_elems);
    check_cuda(cudaGetLastError(), "init_matrix(matrix_a) launch");
    init_matrix<Tb><<<216, 1024>>>(matrix_b, static_cast<fp16>(2.f), b_elems);
    check_cuda(cudaGetLastError(), "init_matrix(matrix_b) launch");

    // init gemm
    int lda = k, ldb = k, ldd = m;
    std::unique_ptr<cublasLtGemm> gemm = std::make_unique<cublasLtGemm>();
    gemm->Init();
    gemm->Setup(m, n, k, batch, lda, ldb, ldd, get_datatype<Ta>(), get_datatype<Tb>(), get_datatype<Tout>(),
                CUBLAS_OP_T, CUBLAS_OP_N, CUBLASLT_EPILOGUE_DEFAULT);

    // Ask for a single algorithm whose workspace is at most 2 * m * n bytes (computed without int overflow).
    void *workspace = nullptr;
    const size_t workspace_size = gemm->GetAlgorithm(1, 2 * static_cast<size_t>(m) * n);
    if (workspace_size > 0) {
        check_cuda(cudaMalloc(&workspace, workspace_size), "cudaMalloc(workspace)");
    }

    // timer
    float time = 0.f;
    cudaEvent_t startTime, endTime;
    check_cuda(cudaEventCreate(&startTime), "cudaEventCreate(start)");
    check_cuda(cudaEventCreate(&endTime), "cudaEventCreate(end)");

    // The output buffer is passed as both C and D; beta is 0.
    for (int i = 0; i < warmup; i++) {
        gemm->Execute(reinterpret_cast<void *>(matrix_a), reinterpret_cast<void *>(matrix_b),
                      reinterpret_cast<void *>(matrix_out), reinterpret_cast<void *>(matrix_out), 1.f, 0.f,
                      workspace, workspace_size, 0);
    }

    check_cuda(cudaEventRecord(startTime, 0), "cudaEventRecord(start)");
    for (int i = 0; i < iter; i++) {
        gemm->Execute(reinterpret_cast<void *>(matrix_a), reinterpret_cast<void *>(matrix_b),
                      reinterpret_cast<void *>(matrix_out), reinterpret_cast<void *>(matrix_out), 1.f, 0.f,
                      workspace, workspace_size, 0);
    }
    check_cuda(cudaEventRecord(endTime, 0), "cudaEventRecord(end)");
    // Synchronizing on the end event also surfaces asynchronous failures of the timed work.
    check_cuda(cudaEventSynchronize(endTime), "cudaEventSynchronize(end)");
    check_cuda(cudaEventElapsedTime(&time, startTime, endTime), "cudaEventElapsedTime");

    // deallocate (the events were previously leaked)
    check_cuda(cudaEventDestroy(startTime), "cudaEventDestroy(start)");
    check_cuda(cudaEventDestroy(endTime), "cudaEventDestroy(end)");
    check_cuda(cudaFree(workspace), "cudaFree(workspace)");
    check_cuda(cudaFree(matrix_a), "cudaFree(matrix_a)");
    check_cuda(cudaFree(matrix_b), "cudaFree(matrix_b)");
    check_cuda(cudaFree(matrix_out), "cudaFree(matrix_out)");

    return (time * 1e3 / iter);
}
/**
 * Benchmark one GEMM configuration and print a single tab-separated result line:
 *   m  n  k  batch  time_us  tflops
 *
 * @tparam Ta    element type of A.
 * @tparam Tb    element type of B (defaults to Ta).
 * @tparam Tout  element type of the output (defaults to fp16).
 * @param args   benchmark configuration; must not be null.
 */
template <typename Ta, typename Tb = Ta, typename Tout = fp16> void run(Args *args) {
    const float time_us =
        timing_matmul_tn<Ta, Tb, Tout>(args->m, args->n, args->k, args->batch, args->warmup, args->iter);

    // m * n * (2k - 1) floating point operations per GEMM, accumulated in float as before.
    const float flops_per_gemm = float(args->m) * float(args->n) * float(2 * args->k - 1);
    // flops / 1e6 / microseconds == TFLOPS; scaled by the effective batch count.
    const auto tflops = flops_per_gemm / 1e6 / time_us * std::max(args->batch, 1);

    // m n k batch time_us tflops
    printf("%d\t%d\t%d\t%d\t%f\t%f\n", args->m, args->n, args->k, args->batch, time_us, tflops);
}
/**
 * Entry point: parse the command line and run the benchmark for the requested input type.
 *
 * "fp16" and "fp8e4m3" use the same type for A and B; "fp8e5m2" pairs an
 * fp8e5m2 A with an fp8e4m3 B. Any other type name throws std::invalid_argument.
 */
int main(int argc, char **argv) {
    Args args;
    process_args(argc, argv, &args);

    const std::string &type_name = args.in_type;
    if (type_name == "fp16") {
        run<fp16>(&args);
        return 0;
    }
    if (type_name == "fp8e4m3") {
        run<fp8e4m3>(&args);
        return 0;
    }
    if (type_name == "fp8e5m2") {
        run<fp8e5m2, fp8e4m3>(&args);
        return 0;
    }
    throw std::invalid_argument("Unknown type " + args.in_type);
}
superbench/benchmarks/micro_benchmarks/cublaslt_fp8_gemm/cublaslt_utils.cc
0 → 100644
View file @
0591da5f
// Copyright(c) Microsoft Corporation.
// Licensed under the MIT License.
#include "cublaslt_utils.h"
void
cublasLtGemm
::
Init
()
{
cublasLtHandle_t
handle
;
checkCublasStatus
(
cublasLtCreate
(
&
handle
));
handle_
.
reset
(
handle
);
/* preference can be initialized without arguments */
cublasLtMatmulPreference_t
preference
;
checkCublasStatus
(
cublasLtMatmulPreferenceCreate
(
&
preference
));
preference_
.
reset
(
preference
);
}
void
cublasLtGemm
::
Setup
(
int
m
,
int
n
,
int
k
,
int
batch
,
int
lda
,
int
ldb
,
int
ldd
,
cudaDataType_t
a_type
,
cudaDataType_t
b_type
,
cudaDataType_t
d_type
,
cublasOperation_t
transa
,
cublasOperation_t
transb
,
cublasLtEpilogue_t
epilogue
,
void
*
a_scale_inverse
,
/* only need to be set for fp8 */
void
*
b_scale_inverse
/* only need to be set for fp8 */
)
{
cublasLtMatrixLayout_t
a_desc
=
nullptr
,
b_desc
=
nullptr
,
c_desc
=
nullptr
,
d_desc
=
nullptr
;
// force c_type
cudaDataType_t
c_type
=
CUDA_R_16F
;
// Create matrix descriptors.
checkCublasStatus
(
cublasLtMatrixLayoutCreate
(
&
a_desc
,
a_type
,
transa
==
CUBLAS_OP_N
?
m
:
k
,
transa
==
CUBLAS_OP_N
?
k
:
m
,
lda
));
checkCublasStatus
(
cublasLtMatrixLayoutCreate
(
&
b_desc
,
b_type
,
transb
==
CUBLAS_OP_N
?
k
:
n
,
transb
==
CUBLAS_OP_N
?
n
:
k
,
ldb
));
checkCublasStatus
(
cublasLtMatrixLayoutCreate
(
&
c_desc
,
c_type
,
m
,
n
,
ldd
));
checkCublasStatus
(
cublasLtMatrixLayoutCreate
(
&
d_desc
,
d_type
,
m
,
n
,
ldd
));
// strided batch gemm
if
(
batch
>
0
)
{
int64_t
stridea
=
m
*
k
,
strideb
=
k
*
n
,
stridec
=
m
*
n
,
strided
=
m
*
n
;
checkCublasStatus
(
cublasLtMatrixLayoutSetAttribute
(
a_desc
,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT
,
&
batch
,
sizeof
(
batch
)));
checkCublasStatus
(
cublasLtMatrixLayoutSetAttribute
(
a_desc
,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
,
&
stridea
,
sizeof
(
stridea
)));
checkCublasStatus
(
cublasLtMatrixLayoutSetAttribute
(
b_desc
,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT
,
&
batch
,
sizeof
(
batch
)));
checkCublasStatus
(
cublasLtMatrixLayoutSetAttribute
(
b_desc
,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
,
&
strideb
,
sizeof
(
strideb
)));
checkCublasStatus
(
cublasLtMatrixLayoutSetAttribute
(
c_desc
,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT
,
&
batch
,
sizeof
(
batch
)));
checkCublasStatus
(
cublasLtMatrixLayoutSetAttribute
(
c_desc
,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
,
&
stridec
,
sizeof
(
stridec
)));
checkCublasStatus
(
cublasLtMatrixLayoutSetAttribute
(
d_desc
,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT
,
&
batch
,
sizeof
(
batch
)));
checkCublasStatus
(
cublasLtMatrixLayoutSetAttribute
(
d_desc
,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET
,
&
strided
,
sizeof
(
strided
)));
}
a_desc_
.
reset
(
a_desc
);
b_desc_
.
reset
(
b_desc
);
c_desc_
.
reset
(
c_desc
);
d_desc_
.
reset
(
d_desc
);
// default to tf32 except for e5m2 inputs where the config is not supported
cublasComputeType_t
gemm_compute_type
=
(
a_type
==
CUDA_R_8F_E5M2
||
b_type
==
CUDA_R_8F_E5M2
||
a_type
==
CUDA_R_8F_E4M3
||
b_type
==
CUDA_R_8F_E4M3
)
?
CUBLAS_COMPUTE_32F
:
CUBLAS_COMPUTE_32F_FAST_TF32
;
cublasLtMatmulDesc_t
op_desc
=
nullptr
;
checkCublasStatus
(
cublasLtMatmulDescCreate
(
&
op_desc
,
gemm_compute_type
,
CUDA_R_32F
));
op_desc_
.
reset
(
op_desc
);
if
(
a_type
==
CUDA_R_8F_E5M2
||
b_type
==
CUDA_R_8F_E5M2
||
a_type
==
CUDA_R_8F_E4M3
||
b_type
==
CUDA_R_8F_E4M3
)
{
// disable fastAccuMode, set to 0
int8_t
fastAccuMode
=
1
;
cublasLtMatmulDescSetAttribute
(
op_desc
,
CUBLASLT_MATMUL_DESC_FAST_ACCUM
,
&
fastAccuMode
,
sizeof
(
fastAccuMode
));
}
checkCublasStatus
(
cublasLtMatmulDescSetAttribute
(
op_desc_
.
get
(),
CUBLASLT_MATMUL_DESC_TRANSA
,
&
transa
,
sizeof
(
transa
)));
checkCublasStatus
(
cublasLtMatmulDescSetAttribute
(
op_desc_
.
get
(),
CUBLASLT_MATMUL_DESC_TRANSB
,
&
transb
,
sizeof
(
transb
)));
if
(
a_scale_inverse
!=
nullptr
)
{
checkCublasStatus
(
cublasLtMatmulDescSetAttribute
(
op_desc_
.
get
(),
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER
,
&
a_scale_inverse
,
sizeof
(
a_scale_inverse
)));
}
if
(
b_scale_inverse
!=
nullptr
)
{
checkCublasStatus
(
cublasLtMatmulDescSetAttribute
(
op_desc_
.
get
(),
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER
,
&
b_scale_inverse
,
sizeof
(
b_scale_inverse
)));
}
checkCublasStatus
(
cublasLtMatmulDescSetAttribute
(
op_desc_
.
get
(),
CUBLASLT_MATMUL_DESC_EPILOGUE
,
&
epilogue
,
sizeof
(
epilogue
)));
}
size_t
cublasLtGemm
::
GetAlgorithm
(
int
max_algorithm_count
,
size_t
max_workspace_size
)
{
checkCublasStatus
(
cublasLtMatmulPreferenceSetAttribute
(
preference_
.
get
(),
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES
,
&
max_workspace_size
,
sizeof
(
max_workspace_size
)));
int
found_algorithm_count
=
0
;
std
::
vector
<
cublasLtMatmulHeuristicResult_t
>
results
(
max_algorithm_count
);
// Though we query all of possible algorithm, we will use the first later
checkCublasStatus
(
cublasLtMatmulAlgoGetHeuristic
(
handle_
.
get
(),
op_desc_
.
get
(),
a_desc_
.
get
(),
b_desc_
.
get
(),
c_desc_
.
get
(),
d_desc_
.
get
(),
preference_
.
get
(),
max_algorithm_count
,
results
.
data
(),
&
found_algorithm_count
));
if
(
found_algorithm_count
==
0
)
{
throw
std
::
runtime_error
(
"Unable to find any suitable algorithms"
);
}
results
.
resize
(
found_algorithm_count
);
heuristic_results_
=
std
::
move
(
results
);
return
heuristic_results_
.
front
().
workspaceSize
;
}
void
cublasLtGemm
::
Execute
(
void
*
matrix_a
,
void
*
matrix_b
,
void
*
matrix_c
,
void
*
matrix_d
,
float
alpha
,
float
beta
,
void
*
workspace
,
size_t
workspace_size
,
cudaStream_t
stream
)
{
checkCublasStatus
(
cublasLtMatmul
(
handle_
.
get
(),
op_desc_
.
get
(),
static_cast
<
const
void
*>
(
&
alpha
),
/* alpha */
matrix_a
,
/* A */
a_desc_
.
get
(),
matrix_b
,
/* B */
b_desc_
.
get
(),
static_cast
<
const
void
*>
(
&
beta
),
/* beta */
matrix_c
,
/* C */
c_desc_
.
get
(),
matrix_d
,
/* D */
d_desc_
.
get
(),
&
heuristic_results_
.
front
().
algo
,
/* algo */
workspace
,
/* workspace */
workspace_size
,
stream
));
/* stream */
}
superbench/benchmarks/micro_benchmarks/cublaslt_fp8_gemm/cublaslt_utils.h
0 → 100644
View file @
0591da5f
// Copyright(c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <memory>
#include <stdio.h>
#include <vector>
#include <cublasLt.h>
/**
 * Throw if a cuBLAS call did not succeed.
 *
 * On failure the status string is printed with printf and a std::logic_error
 * carrying the message "cuBLAS API failed" is thrown; on success nothing happens.
 *
 * @param status  return value of a cuBLAS / cuBLASLt API call.
 */
inline void checkCublasStatus(cublasStatus_t status) {
    if (status == CUBLAS_STATUS_SUCCESS) {
        return;
    }
    printf("cuBLAS API failed with status %s\n", cublasGetStatusString(status));
    throw std::logic_error("cuBLAS API failed");
}
/**
 * Owning wrapper around the cuBLASLt objects needed to run one GEMM configuration.
 *
 * Usage order, as the method bodies require: Init(), Setup(), GetAlgorithm(), then Execute().
 * Every cuBLASLt object is held in a std::unique_ptr whose deleter calls the matching destroy function.
 */
class cublasLtGemm {
  public:
    // Deleter for cublasLtHandle_t.
    struct HandleDestroyer {
        void operator()(cublasLtHandle_t h) const { cublasLtDestroy(h); }
    };

    // Deleter for cublasLtMatmulDesc_t.
    struct MatmulDescDestroyer {
        void operator()(cublasLtMatmulDesc_t desc) const { cublasLtMatmulDescDestroy(desc); }
    };

    // Deleter for cublasLtMatrixLayout_t.
    struct LayoutDestroyer {
        void operator()(cublasLtMatrixLayout_t desc) const { cublasLtMatrixLayoutDestroy(desc); }
    };

    // Deleter for cublasLtMatmulPreference_t.
    struct MatmulPreferenceDestroyer {
        void operator()(cublasLtMatmulPreference_t pref) const { cublasLtMatmulPreferenceDestroy(pref); }
    };

    // The cuBLASLt handle types are pointers; unique_ptr wants the pointee type.
    using UniqueHandle = std::unique_ptr<std::remove_pointer<cublasLtHandle_t>::type, HandleDestroyer>;
    using UniqueOpDesc = std::unique_ptr<std::remove_pointer<cublasLtMatmulDesc_t>::type, MatmulDescDestroyer>;
    using UniqueLayoutDesc = std::unique_ptr<std::remove_pointer<cublasLtMatrixLayout_t>::type, LayoutDestroyer>;
    using UniqueMatmulPreference =
        std::unique_ptr<std::remove_pointer<cublasLtMatmulPreference_t>::type, MatmulPreferenceDestroyer>;

    // Create the handle and the matmul preference.
    void Init();

    // Build the A/B/C/D layouts and the operation descriptor for the given problem.
    // The two scale pointers are optional and skipped when null.
    void Setup(int m, int n, int k, int batch, int lda, int ldb, int ldd, cudaDataType_t a_type,
               cudaDataType_t b_type, cudaDataType_t d_type, cublasOperation_t transa, cublasOperation_t transb,
               cublasLtEpilogue_t epilogue, void *a_scale_inverse = nullptr, void *b_scale_inverse = nullptr);

    // Query heuristic candidates; returns the workspace size of the first one.
    size_t GetAlgorithm(int max_algorithm_count, size_t max_workspace_size);

    // Run one matmul using the first stored heuristic result.
    void Execute(void *matrix_a, void *matrix_b, void *matrix_c, void *matrix_d, float alpha, float beta,
                 void *workspace, size_t workspace_size, cudaStream_t stream);

  private:
    // Declaration order is kept as-is: members are destroyed in reverse order.
    UniqueHandle handle_;
    UniqueOpDesc op_desc_;
    UniqueLayoutDesc a_desc_;
    UniqueLayoutDesc b_desc_;
    UniqueLayoutDesc c_desc_;
    UniqueLayoutDesc d_desc_;
    UniqueMatmulPreference preference_;
    // Candidates returned by the last GetAlgorithm() call.
    std::vector<cublasLtMatmulHeuristicResult_t> heuristic_results_;
};
superbench/benchmarks/micro_benchmarks/cuda_common.cmake
View file @
0591da5f
...
@@ -32,4 +32,7 @@ if(NOT DEFINED NVCC_ARCHS_SUPPORTED)
...
@@ -32,4 +32,7 @@ if(NOT DEFINED NVCC_ARCHS_SUPPORTED)
if
(
NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.1
)
if
(
NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.1
)
list
(
APPEND NVCC_ARCHS_SUPPORTED 86
)
list
(
APPEND NVCC_ARCHS_SUPPORTED 86
)
endif
()
endif
()
if
(
NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.8
)
list
(
APPEND NVCC_ARCHS_SUPPORTED 90
)
endif
()
endif
()
endif
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment