Commit 35e5c532 authored by aska-0096
Browse files

clang format

parent b010b095
...@@ -252,25 +252,27 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -252,25 +252,27 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
B1Spec, B1Spec,
CSpec>; CSpec>;
static auto MakeAGridDescriptor_AK0_M_AK1(const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec, static auto MakeAGridDescriptor_AK0_M_AK1(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides_vec)
{ {
return Transform::MakeAGridDescriptor_AK0_M_AK1( return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
Number<AK1>{}); Number<AK1>{});
} }
static auto MakeBGridDescriptor_BK0_N_BK1(const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths_vec, static auto MakeBGridDescriptor_BK0_N_BK1(
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides_vec)
{ {
return Transform::MakeB0GridDescriptor_BK0_N_BK1( return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec),
Number<BK1>{}); Number<BK1>{});
} }
static auto static auto MakeB1GridDescriptor_BK0_N_BK1(
MakeB1GridDescriptor_BK0_N_BK1(const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_lengths_vec, const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_lengths_vec,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_strides_vec) const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_strides_vec)
{ {
return Transform::MakeB1GridDescriptor_BK0_N_BK1( return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec, Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec,
...@@ -457,10 +459,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -457,10 +459,14 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides, const std::array<index_t, NumDimG + NumDimM + NumDimN>& a_gs_ms_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths, const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_lengths,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides, const std::array<index_t, NumDimG + NumDimM + NumDimN>& b_gs_ns_ks_strides,
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths const std::array<index_t, NumDimG + NumDimM + NumDimN>&
const std::array<index_t, NumDimG + NumDimM + NumDimN>& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::array<index_t, NumDimG + NumDimM + NumDimN>&
const std::array<index_t, NumDimG + NumDimM + NumDimN>& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides
const std::array<index_t, NumDimG + NumDimM + NumDimN>&
c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths
const std::array<index_t, NumDimG + NumDimM + NumDimN>&
c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths, const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides, const std::array<std::vector<ck::index_t>, NumD0Tensor>& acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumD1Tensor>& const std::array<std::vector<ck::index_t>, NumD1Tensor>&
...@@ -836,7 +842,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -836,7 +842,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
C1DEElementwiseOperation c1de_element_op) C1DEElementwiseOperation c1de_element_op)
{ {
constexpr auto dimension = NumDimG + NumDimM + NumDimN; constexpr auto dimension = NumDimG + NumDimM + NumDimN;
std::array<index_t, dimension> a_gs_ms_ks_lengths_{}; std::array<index_t, dimension> a_gs_ms_ks_lengths_{};
std::array<index_t, dimension> a_gs_ms_ks_strides_{}; std::array<index_t, dimension> a_gs_ms_ks_strides_{};
std::array<index_t, dimension> b_gs_ns_ks_lengths_{}; std::array<index_t, dimension> b_gs_ns_ks_lengths_{};
...@@ -846,21 +852,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -846,21 +852,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{}; // c_gs_ms_os_lengths std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{}; // c_gs_ms_os_lengths
std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{}; // c_gs_ms_os_strides std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{}; // c_gs_ms_os_strides
std::copy(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.begin()+dimension, a_gs_ms_ks_lengths_.begin()); std::copy(a_gs_ms_ks_lengths.begin(),
std::copy(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.begin()+dimension, a_gs_ms_ks_strides_.begin()); a_gs_ms_ks_lengths.begin() + dimension,
std::copy(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.begin()+dimension, b_gs_ns_ks_lengths_.begin()); a_gs_ms_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.begin()+dimension, b_gs_ns_ks_strides_.begin()); std::copy(a_gs_ms_ks_strides.begin(),
a_gs_ms_ks_strides.begin() + dimension,
a_gs_ms_ks_strides_.begin());
std::copy(b_gs_ns_ks_lengths.begin(),
b_gs_ns_ks_lengths.begin() + dimension,
b_gs_ns_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(),
b_gs_ns_ks_strides.begin() + dimension,
b_gs_ns_ks_strides_.begin());
std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(), std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(),
b1_gs_gemm1ns_gemm1ks_lengths.begin()+dimension, b1_gs_gemm1ns_gemm1ks_lengths.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths
std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(), std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(),
b1_gs_gemm1ns_gemm1ks_strides.begin()+dimension, b1_gs_gemm1ns_gemm1ks_strides.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides
std::copy(c_gs_ms_gemm1ns_lengths.begin(), std::copy(c_gs_ms_gemm1ns_lengths.begin(),
c_gs_ms_gemm1ns_lengths.begin()+dimension, c_gs_ms_gemm1ns_lengths.begin() + dimension,
c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths
std::copy(c_gs_ms_gemm1ns_strides.begin(), std::copy(c_gs_ms_gemm1ns_strides.begin(),
c_gs_ms_gemm1ns_strides.begin()+dimension, c_gs_ms_gemm1ns_strides.begin() + dimension,
c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides
return Argument{p_a, return Argument{p_a,
...@@ -930,21 +944,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -930,21 +944,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{}; // c_gs_ms_os_lengths std::array<index_t, dimension> c_gs_ms_gemm1ns_lengths_{}; // c_gs_ms_os_lengths
std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{}; // c_gs_ms_os_strides std::array<index_t, dimension> c_gs_ms_gemm1ns_strides_{}; // c_gs_ms_os_strides
std::copy(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.begin()+dimension, a_gs_ms_ks_lengths_.begin()); std::copy(a_gs_ms_ks_lengths.begin(),
std::copy(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.begin()+dimension, a_gs_ms_ks_strides_.begin()); a_gs_ms_ks_lengths.begin() + dimension,
std::copy(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.begin()+dimension, b_gs_ns_ks_lengths_.begin()); a_gs_ms_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.begin()+dimension, b_gs_ns_ks_strides_.begin()); std::copy(a_gs_ms_ks_strides.begin(),
a_gs_ms_ks_strides.begin() + dimension,
a_gs_ms_ks_strides_.begin());
std::copy(b_gs_ns_ks_lengths.begin(),
b_gs_ns_ks_lengths.begin() + dimension,
b_gs_ns_ks_lengths_.begin());
std::copy(b_gs_ns_ks_strides.begin(),
b_gs_ns_ks_strides.begin() + dimension,
b_gs_ns_ks_strides_.begin());
std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(), std::copy(b1_gs_gemm1ns_gemm1ks_lengths.begin(),
b1_gs_gemm1ns_gemm1ks_lengths.begin()+dimension, b1_gs_gemm1ns_gemm1ks_lengths.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths b1_gs_gemm1ns_gemm1ks_lengths_.begin()); // b1_gs_os_ns_lengths
std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(), std::copy(b1_gs_gemm1ns_gemm1ks_strides.begin(),
b1_gs_gemm1ns_gemm1ks_strides.begin()+dimension, b1_gs_gemm1ns_gemm1ks_strides.begin() + dimension,
b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides b1_gs_gemm1ns_gemm1ks_strides_.begin()); // b1_gs_os_ns_strides
std::copy(c_gs_ms_gemm1ns_lengths.begin(), std::copy(c_gs_ms_gemm1ns_lengths.begin(),
c_gs_ms_gemm1ns_lengths.begin()+dimension, c_gs_ms_gemm1ns_lengths.begin() + dimension,
c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths c_gs_ms_gemm1ns_lengths_.begin()); // c_gs_ms_os_lengths
std::copy(c_gs_ms_gemm1ns_strides.begin(), std::copy(c_gs_ms_gemm1ns_strides.begin(),
c_gs_ms_gemm1ns_strides.begin()+dimension, c_gs_ms_gemm1ns_strides.begin() + dimension,
c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides c_gs_ms_gemm1ns_strides_.begin()); // c_gs_ms_os_strides
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment