Commit 35e5c532 authored by aska-0096's avatar aska-0096
Browse files

clang format

parent b010b095
...@@ -252,7 +252,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -252,7 +252,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1Spec, B1Spec,
CSpec>; CSpec>;
static auto MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec, static auto MakeAGridDescriptor_AK0_M_AK1(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
...@@ -260,7 +261,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -260,7 +261,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
Number<AK1>{}); Number<AK1>{});
} }
static auto MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths_vec, static auto MakeBGridDescriptor_BK0_N_BK1(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides_vec)
{ {
return Transform::MakeB0GridDescriptor_BK0_N_BK1( return Transform::MakeB0GridDescriptor_BK0_N_BK1(
...@@ -268,8 +270,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -268,8 +270,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
Number<BK1>{}); Number<BK1>{});
} }
static auto static auto MakeB1GridDescriptor_BK0_N_BK1(
MakeB1GridDescriptor_BK0_N_BK1(const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_lengths_vec, const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_strides_vec)
{ {
return Transform::MakeB1GridDescriptor_BK0_N_BK1( return Transform::MakeB1GridDescriptor_BK0_N_BK1(
...@@ -457,10 +459,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -457,10 +459,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides, const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths, const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides, const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::array<index_t, NumDimG + NumDimM + NumDimN>&
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::array<index_t, NumDimG + NumDimM + NumDimN>&
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::array<index_t, NumDimG + NumDimM + NumDimN>&
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::array<index_t, NumDimG + NumDimM + NumDimN>&
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths, const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides, const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumD1Tensor>& const std::array<std::vector<ck::index_t>, NumD1Tensor>&
...@@ -846,21 +852,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -846,21 +852,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{}; // c_gs_ms_os_lengths std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{}; // c_gs_ms_os_lengths
std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{}; // c_gs_ms_os_strides std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{}; // c_gs_ms_os_strides
std::copy(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.begin()+dimension, a_gs_ms_ks_lengths_.begin()); std::copy(a_gs_ms_ks_lengths.begin(),
std::copy(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.begin()+dimension, a_gs_ms_ks_strides_.begin()); a_gs_ms_ks_lengths.begin() + dimension,
std::copy(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.begin()+dimension, b_gs_ns_ks_lengths_.begin()); a_gs_ms_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.begin()+dimension, b_gs_ns_ks_strides_.begin()); std::copy(a_gs_ms_ks_strides.begin(),
a_gs_ms_ks_strides.begin() + dimension,
a_gs_ms_ks_strides_.begin());
std::copy(b_gs_ns_ks_lengths.begin(),
b_gs_ns_ks_lengths.begin() + dimension,
b_gs_ns_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(),
b_gs_ns_ks_strides.begin() + dimension,
b_gs_ns_ks_strides_.begin());
std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(), std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(),
b1_gs_gemm1ns_gemm1ks_lengths.begin()+dimension, b1_gs_gemm1ns_gemm1ks_lengths.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths
std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(), std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(),
b1_gs_gemm1ns_gemm1ks_strides.begin()+dimension, b1_gs_gemm1ns_gemm1ks_strides.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides
std::copy(c_gs_ms_gemm1ns_lengths.begin(), std::copy(c_gs_ms_gemm1ns_lengths.begin(),
c_gs_ms_gemm1ns_lengths.begin()+dimension, c_gs_ms_gemm1ns_lengths.begin() + dimension,
c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths
std::copy(c_gs_ms_gemm1ns_strides.begin(), std::copy(c_gs_ms_gemm1ns_strides.begin(),
c_gs_ms_gemm1ns_strides.begin()+dimension, c_gs_ms_gemm1ns_strides.begin() + dimension,
c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides
return Argument{p_a, return Argument{p_a,
...@@ -930,21 +944,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -930,21 +944,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{}; // c_gs_ms_os_lengths std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{}; // c_gs_ms_os_lengths
std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{}; // c_gs_ms_os_strides std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{}; // c_gs_ms_os_strides
std::copy(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.begin()+dimension, a_gs_ms_ks_lengths_.begin()); std::copy(a_gs_ms_ks_lengths.begin(),
std::copy(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.begin()+dimension, a_gs_ms_ks_strides_.begin()); a_gs_ms_ks_lengths.begin() + dimension,
std::copy(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.begin()+dimension, b_gs_ns_ks_lengths_.begin()); a_gs_ms_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.begin()+dimension, b_gs_ns_ks_strides_.begin()); std::copy(a_gs_ms_ks_strides.begin(),
a_gs_ms_ks_strides.begin() + dimension,
a_gs_ms_ks_strides_.begin());
std::copy(b_gs_ns_ks_lengths.begin(),
b_gs_ns_ks_lengths.begin() + dimension,
b_gs_ns_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(),
b_gs_ns_ks_strides.begin() + dimension,
b_gs_ns_ks_strides_.begin());
std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(), std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(),
b1_gs_gemm1ns_gemm1ks_lengths.begin()+dimension, b1_gs_gemm1ns_gemm1ks_lengths.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths
std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(), std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(),
b1_gs_gemm1ns_gemm1ks_strides.begin()+dimension, b1_gs_gemm1ns_gemm1ks_strides.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides
std::copy(c_gs_ms_gemm1ns_lengths.begin(), std::copy(c_gs_ms_gemm1ns_lengths.begin(),
c_gs_ms_gemm1ns_lengths.begin()+dimension, c_gs_ms_gemm1ns_lengths.begin() + dimension,
c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths
std::copy(c_gs_ms_gemm1ns_strides.begin(), std::copy(c_gs_ms_gemm1ns_strides.begin(),
c_gs_ms_gemm1ns_strides.begin()+dimension, c_gs_ms_gemm1ns_strides.begin() + dimension,
c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment