"docs/source/git@developer.sourcefind.cn:SIYIXNI/vllm.git" did not exist on "62ec38ea4148bb8147f346f7e01cab6b8a2ec7b6"
Commit 289f15de authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into wmma_gemm

parents 9bd44685 d58b7f51
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
...@@ -288,21 +289,11 @@ int main(int argc, char* argv[]) ...@@ -288,21 +289,11 @@ int main(int argc, char* argv[])
exit(0); exit(0);
} }
Tensor<ADataType> a_ms_ks( Tensor<ADataType> a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides);
std::vector<std::size_t>(a_ms_ks_lengths.begin(), a_ms_ks_lengths.end()), Tensor<BDataType> b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides);
std::vector<std::size_t>(a_ms_ks_strides.begin(), a_ms_ks_strides.end())); Tensor<EDataType> d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides);
Tensor<BDataType> b_ns_ks( Tensor<EDataType> e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
std::vector<std::size_t>(b_ns_ks_lengths.begin(), b_ns_ks_lengths.end()), Tensor<EDataType> e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides);
std::vector<std::size_t>(b_ns_ks_strides.begin(), b_ns_ks_strides.end()));
Tensor<EDataType> d_ms_ns(
std::vector<std::size_t>(d_ms_ns_lengths.begin(), d_ms_ns_lengths.end()),
std::vector<std::size_t>(d_ms_ns_strides.begin(), d_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_host_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
Tensor<EDataType> e_ms_ns_device_result(
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl;
std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl;
...@@ -368,20 +359,14 @@ int main(int argc, char* argv[]) ...@@ -368,20 +359,14 @@ int main(int argc, char* argv[])
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
ck::index_t M = std::accumulate(e_ms_ns_lengths.begin(), ck::index_t M =
e_ms_ns_lengths.begin() + NumDimM, ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t N = std::accumulate(e_ms_ns_lengths.begin() + NumDimM, ck::index_t N = ck::accumulate_n<ck::index_t>(
e_ms_ns_lengths.begin() + NumDimM + NumDimN, e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{});
ck::index_t{1},
std::multiplies<ck::index_t>{});
ck::index_t K = std::accumulate(a_ms_ks_lengths.begin() + NumDimM, ck::index_t K = ck::accumulate_n<ck::index_t>(
a_ms_ks_lengths.begin() + NumDimM + NumDimK, a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{});
ck::index_t{1},
std::multiplies<ck::index_t>{});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
...@@ -398,9 +383,7 @@ int main(int argc, char* argv[]) ...@@ -398,9 +383,7 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
Tensor<CShuffleDataType> c_ms_ns_host_result( Tensor<CShuffleDataType> c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides);
std::vector<std::size_t>(e_ms_ns_lengths.begin(), e_ms_ns_lengths.end()),
std::vector<std::size_t>(e_ms_ns_strides.begin(), e_ms_ns_strides.end()));
using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM, using ReferenceOpInstance = ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN, NumDimN,
...@@ -437,7 +420,7 @@ int main(int argc, char* argv[]) ...@@ -437,7 +420,7 @@ int main(int argc, char* argv[])
} }
} }
return ck::utils::check_err(e_ms_ns_device_result.mData, e_ms_ns_host_result.mData) ? 0 : 1; return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1;
} }
return 0; return 0;
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment