Commit 2cbe1b70 authored by wenjh
Browse files

[TEST] Fix build error of test_cublaslt_gemm


Signed-off-by: wenjh <wenjh@sugon.com>
parent 3e38a2ea
...@@ -111,9 +111,9 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c ...@@ -111,9 +111,9 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
DType dtype = TypeInfo<D_Type>::dtype; DType dtype = TypeInfo<D_Type>::dtype;
// pytorch tensor storage is row-major while cublas/rocblas is column-major // pytorch tensor storage is row-major while cublas/rocblas is column-major
Tensor A("A", { k, m }, atype); Tensor A("A", std::vector<size_t>{ k, m }, atype);
Tensor B("B", { n, k }, btype); Tensor B("B", std::vector<size_t>{ n, k }, btype);
Tensor D("D", { n, m }, dtype); Tensor D("D", std::vector<size_t>{ n, m }, dtype);
Tensor bias; Tensor bias;
if(use_bias){ if(use_bias){
bias = Tensor("bias", {m}, bias_type); bias = Tensor("bias", {m}, bias_type);
...@@ -149,7 +149,7 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c ...@@ -149,7 +149,7 @@ void performTest(bool use_bias, bool use_gelu, const size_t m, const size_t k, c
} }
#endif #endif
Tensor Workspace("Workspace", { 33554432 }, DType::kByte); Tensor Workspace("Workspace", std::vector<size_t>{ 33554432 }, DType::kByte);
//perform the gemm in GPU //perform the gemm in GPU
nvte_cublas_gemm(A.data(), nvte_cublas_gemm(A.data(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment