Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cf9bcb31
Commit
cf9bcb31
authored
Jun 11, 2023
by
root
Browse files
minimize arg size
parent
09cc45d3
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
278 additions
and
7 deletions
+278
-7
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
+5
-3
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
...u/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+272
-3
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+1
-1
No files found.
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
View file @
cf9bcb31
...
@@ -54,7 +54,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSpl
...
@@ -54,7 +54,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSpl
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>;
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
64
,
32
,
64
,
32
,
8
,
8
,
32
,
32
,
1
,
2
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
16
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
8
>
;
// clang-format on
// clang-format on
#include "run_grouped_gemm_example.inc"
#include "run_grouped_gemm_example.inc"
...
@@ -66,8 +68,8 @@ int main(int argc, char* argv[])
...
@@ -66,8 +68,8 @@ int main(int argc, char* argv[])
problem_size
.
group_count
=
16
;
problem_size
.
group_count
=
16
;
problem_size
.
Ms
=
{
problem_size
.
Ms
=
{
2
,
2
,
2
,
2
,
2
,
2
,
2
,
2
,
2
,
2
,
2
,
2
,
2
,
2
,
2
,
2
};
167
,
1
83
,
1
77
,
18
1
,
1
53
,
1
39
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
1
48
};
// problem_size.Ms = {2
, 1, 1, 1, 1, 1
, 3, 4, 3, 5, 2, 4, 2, 1, 0
, 1};
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
View file @
cf9bcb31
...
@@ -23,8 +23,10 @@ namespace ck {
...
@@ -23,8 +23,10 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
#if 1
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
GemmDesc
,
typename
GemmDesc
,
typename
GemmSharedArgs
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
>
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
>
__global__
void
__global__
void
...
@@ -32,7 +34,7 @@ __global__ void
...
@@ -32,7 +34,7 @@ __global__ void
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_grouped_gemm_xdl_splitk
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
gemm_descs_const
,
kernel_grouped_gemm_xdl_splitk
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
gemm_descs_const
,
const
index_t
group_count
)
const
GemmSharedArgs
gemm_shared_args
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
defined(__gfx940__))
...
@@ -43,6 +45,87 @@ __global__ void
...
@@ -43,6 +45,87 @@ __global__ void
const
auto
gemm_desc_ptr
=
const
auto
gemm_desc_ptr
=
reinterpret_cast
<
const
GemmDesc
*>
(
cast_pointer_to_generic_address_space
(
gemm_descs_const
));
reinterpret_cast
<
const
GemmDesc
*>
(
cast_pointer_to_generic_address_space
(
gemm_descs_const
));
const
index_t
group_id
=
block_id
/
gemm_shared_args
.
block_size
;
#if 1
// const auto M = gemm_shared_args.M;
// const auto N = gemm_shared_args.N;
// const auto K = gemm_shared_args.K;
// const auto StrideA = gemm_shared_args.StrideA;
// const auto StrideB = gemm_shared_args.StrideB;
// const auto StrideC = gemm_shared_args.StrideC;
// const auto MPadded = gemm_shared_args.MPadded;
// const auto NPadded = gemm_shared_args.NPadded;
// const auto KPadded = gemm_shared_args.KPadded;
// const auto K0 = gemm_shared_args.KPadded;
// const auto k_batch = gemm_shared_args.k_batch;
// M = 2 N = 768 K = 4608 StrideA = 4608 StrideB = 4608 StrideC = 768 MPadded = 32 NPadded = 768
// KPadded = 4608 K0 = 576 k_batch = 1
const
auto
M
=
2
;
const
auto
N
=
768
;
const
auto
K
=
4608
;
const
auto
StrideA
=
4608
;
const
auto
StrideB
=
4608
;
const
auto
StrideC
=
768
;
const
auto
MPadded
=
32
;
const
auto
NPadded
=
768
;
const
auto
KPadded
=
4608
;
const
auto
K0
=
576
;
const
auto
k_batch
=
1
;
// const auto block_2_ctile_map = gemm_shared_args.block_2_ctile_map;
#endif
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
>(
gemm_desc_ptr
[
group_id
].
karg_
.
p_a_grid
,
gemm_desc_ptr
[
group_id
].
karg_
.
p_b_grid
,
gemm_desc_ptr
[
group_id
].
karg_
.
p_c_grid
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
MPadded
,
NPadded
,
KPadded
,
K0
,
k_batch
,
static_cast
<
void
*>
(
p_shared
),
gemm_desc_ptr
[
group_id
].
karg_
.
block_2_ctile_map
);
#else
ignore
=
gemm_descs_const
;
ignore
=
all_gemm_block_size
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
#else
template
<
typename
GridwiseGemm
,
typename
GemmDesc
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_gemm_xdl_splitk
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
gemm_descs_const
,
// const index_t group_count
const
index_t
block_size
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
constexpr
index_t
shared_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
();
__shared__
uint8_t
p_shared
[
shared_size
];
const
index_t
block_id
=
get_block_1d_id
();
const
auto
gemm_desc_ptr
=
reinterpret_cast
<
const
GemmDesc
*>
(
cast_pointer_to_generic_address_space
(
gemm_descs_const
));
#if 0
index_t left = 0;
index_t left = 0;
index_t right = group_count;
index_t right = group_count;
index_t group_id = index_t((left + right) / 2);
index_t group_id = index_t((left + right) / 2);
...
@@ -60,16 +143,51 @@ __global__ void
...
@@ -60,16 +143,51 @@ __global__ void
}
}
group_id = index_t((left + right) / 2);
group_id = index_t((left + right) / 2);
}
}
#else
const
index_t
group_id
=
block_id
/
block_size
;
#endif
#if 0
const auto N = gemm_desc_ptr[0].karg_.N;
const auto K = gemm_desc_ptr[0].karg_.K;
const auto StrideB = gemm_desc_ptr[0].karg_.StrideB;
const auto NPadded = gemm_desc_ptr[0].karg_.NPadded;
const auto KPadded = gemm_desc_ptr[0].karg_.KPadded;
const auto K0 = gemm_desc_ptr[0].karg_.KPadded;
const auto k_batch = gemm_desc_ptr[0].karg_.k_batch;
const auto block_2_ctile_map = gemm_desc_ptr[0].block_2_ctile_map_;
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
gemm_desc_ptr[group_id].karg_.p_a_grid,
gemm_desc_ptr[group_id].karg_.p_b_grid,
gemm_desc_ptr[group_id].karg_.p_c_grid,
gemm_desc_ptr[group_id].karg_.M,
N,
K,
gemm_desc_ptr[group_id].karg_.StrideA,
StrideB,
gemm_desc_ptr[group_id].karg_.StrideC,
gemm_desc_ptr[group_id].karg_.MPadded,
NPadded,
KPadded,
K0,
k_batch,
static_cast<void*>(p_shared),
block_2_ctile_map);
#else
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
>(
gemm_desc_ptr
[
group_id
].
karg_
,
gemm_desc_ptr
[
group_id
].
karg_
,
static_cast
<
void
*>
(
p_shared
),
static_cast
<
void
*>
(
p_shared
),
gemm_desc_ptr
[
group_id
].
block_2_ctile_map_
);
gemm_desc_ptr
[
group_id
].
block_2_ctile_map_
);
#endif
#else
#else
ignore
=
gemm_descs_const
;
ignore
=
gemm_descs_const
;
ignore
=
group_count
;
ignore
=
group_count
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
}
#endif
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
...
@@ -406,10 +524,104 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -406,10 +524,104 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
}
}
}
}
struct
ArgumentMsN1K1
{
const
ADataType
*
p_a_grid
;
const
BDataType
*
p_b_grid
;
EDataType
*
p_c_grid
;
// index_t M;
// index_t StrideA;
// index_t StrideC;
// index_t MPadded;
GroupedGemmBlock2ETileMap
block_2_ctile_map
;
};
struct
GemmTransKernelArgMsN1K1
{
ArgumentMsN1K1
karg_
;
};
#if 1
std
::
vector
<
GemmTransKernelArgMsN1K1
>
gemm_kernel_args_msn1k1_
;
index_t
all_gemm_block_size
=
arg
.
gemm_kernel_args_
[
0
].
block_end_
-
arg
.
gemm_kernel_args_
[
0
].
block_start_
;
for
(
const
auto
&
trans_arg
:
arg
.
gemm_kernel_args_
)
{
auto
karg
=
ArgumentMsN1K1
{
trans_arg
.
karg_
.
p_a_grid
,
trans_arg
.
karg_
.
p_b_grid
,
trans_arg
.
karg_
.
p_c_grid
,
trans_arg
.
block_2_ctile_map_
};
// auto block_size = trans_arg.block_end_ - trans_arg.block_start_;
// std::cout << "trans_arg.block_start_: " << trans_arg.block_start_
// << " trans_arg.block_end_: " << trans_arg.block_end_
// << " block_size: " << block_size << std::endl;
gemm_kernel_args_msn1k1_
.
push_back
({
karg
});
}
#endif
#if 0
hip_check_error(hipMemcpy(arg.p_workspace_,
hip_check_error(hipMemcpy(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice));
hipMemcpyHostToDevice));
#else
struct
GemmSharedArgs
{
index_t
block_size
;
// index_t M;
// index_t N;
// index_t K;
// index_t StrideA;
// index_t StrideB;
// index_t StrideC;
// index_t MPadded;
// index_t NPadded;
// index_t KPadded;
// index_t K0;
// index_t k_batch;
// GroupedGemmBlock2ETileMap block_2_ctile_map;
#if 0
void print()
{
std::cout << "block_size = " << block_size << " M = " << M << " N = " << N
<< " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC
<< " MPadded = " << MPadded << " NPadded = " << NPadded
<< " KPadded = " << KPadded << " K0 = " << K0
<< " k_batch = " << k_batch << std::endl;
}
#endif
};
auto
shared_karg
=
GemmSharedArgs
{
all_gemm_block_size
,
// arg.gemm_kernel_args_[0].karg_.M,
// arg.gemm_kernel_args_[0].karg_.N,
// arg.gemm_kernel_args_[0].karg_.K,
// arg.gemm_kernel_args_[0].karg_.StrideA,
// arg.gemm_kernel_args_[0].karg_.StrideB,
// arg.gemm_kernel_args_[0].karg_.StrideC,
// arg.gemm_kernel_args_[0].karg_.MPadded,
// arg.gemm_kernel_args_[0].karg_.NPadded,
// arg.gemm_kernel_args_[0].karg_.KPadded,
// arg.gemm_kernel_args_[0].karg_.K0,
// arg.gemm_kernel_args_[0].karg_.k_batch,
// arg.gemm_kernel_args_[0].block_2_ctile_map_,
};
// shared_karg.print();
hip_check_error
(
hipMemcpy
(
arg
.
p_workspace_
,
gemm_kernel_args_msn1k1_
.
data
(),
gemm_kernel_args_msn1k1_
.
size
()
*
sizeof
(
GemmTransKernelArgMsN1K1
),
hipMemcpyHostToDevice
));
#endif
float
ave_time
=
0
;
float
ave_time
=
0
;
...
@@ -431,9 +643,14 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -431,9 +643,14 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
cast_pointer_to_constant_address_space
(
arg
.
p_workspace_
),
cast_pointer_to_constant_address_space
(
arg
.
p_workspace_
),
arg
.
gemm_kernel_args_
.
size
());
shared_karg
// all_gemm_block_size
);
};
};
std
::
cout
<<
"all_have_main_k0_block_loop: "
<<
all_have_main_k0_block_loop
<<
" all_have_kbatch_gt_one: "
<<
all_have_kbatch_gt_one
<<
std
::
endl
;
#if 0
if(all_have_main_k0_block_loop)
if(all_have_main_k0_block_loop)
{
{
if(all_have_kbatch_gt_one)
if(all_have_kbatch_gt_one)
...
@@ -480,6 +697,58 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -480,6 +697,58 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
Run(kernel);
Run(kernel);
}
}
}
}
#else
if
(
all_have_main_k0_block_loop
)
{
if
(
all_have_kbatch_gt_one
)
{
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArgMsN1K1
,
GemmSharedArgs
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArgMsN1K1
,
GemmSharedArgs
,
true
,
InMemoryDataOperationEnum
::
Set
>
;
Run
(
kernel
);
}
}
else
{
if
(
all_have_kbatch_gt_one
)
{
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArgMsN1K1
,
GemmSharedArgs
,
false
,
InMemoryDataOperationEnum
::
AtomicAdd
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_grouped_gemm_xdl_splitk
<
GridwiseGemm
,
GemmTransKernelArgMsN1K1
,
GemmSharedArgs
,
false
,
InMemoryDataOperationEnum
::
Set
>
;
Run
(
kernel
);
}
}
#endif
return
ave_time
;
return
ave_time
;
}
}
...
@@ -614,7 +883,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -614,7 +883,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
{
{
return
SetKBatchSize
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
kbatch
);
return
SetKBatchSize
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
kbatch
);
}
}
};
};
// namespace device
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
script/cmake-ck-dev.sh
View file @
cf9bcb31
...
@@ -8,7 +8,7 @@ MY_PROJECT_SOURCE=$1
...
@@ -8,7 +8,7 @@ MY_PROJECT_SOURCE=$1
cmake
\
cmake
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm
\
-D
CMAKE_PREFIX_PATH
=
/opt/rocm
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_FLAGS
=
"-std=c++
17
-O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker
\
-D
CMAKE_CXX_FLAGS
=
"-std=c++
20
-O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker
\
-save-temps=
$PWD
"
\
-save-temps=
$PWD
"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
ON
\
-D
BUILD_DEV
=
ON
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment