Unverified Commit 44401358 authored by triple-mu, committed by GitHub

Fix typos and unify size(s)/stride(s) API calls (#8799)

parent 9c7e3924
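
In short, the commit swaps element indexing into the arrays returned by sizes()/strides() for the single-dimension accessors size(d)/stride(d), gives the bare layout TORCH_CHECKs an error message, and corrects messages that said mat_a where mat_b was meant. A minimal before/after sketch of the pattern, built only from lines that appear in the diff below:

// Before: index into the IntArrayRef returned by sizes()/strides();
// the layout check fails without saying which tensor was wrong.
int page_size = kv_c_and_k_pe_cache.sizes()[1];
TORCH_CHECK(mat_b.strides()[0] == 1);  // Column-major

// After: query the single dimension directly and attach a message.
int page_size = kv_c_and_k_pe_cache.size(1);
TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");  // Column-major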
@@ -105,10 +105,10 @@ typename T::Fmha::Arguments args_from_options(
   hw_info.device_id = q_nope.device().index();
   hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
-  int batches = q_nope.sizes()[0];
-  int page_count_per_seq = page_table.sizes()[1];
-  int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
-  int page_size = kv_c_and_k_pe_cache.sizes()[1];
+  int batches = q_nope.size(0);
+  int page_count_per_seq = page_table.size(1);
+  int page_count_total = kv_c_and_k_pe_cache.size(0);
+  int page_size = kv_c_and_k_pe_cache.size(1);
   int max_seq_len = page_size * page_count_per_seq;
   using TileShapeH = typename T::TileShapeH;
   using TileShapeD = typename T::TileShapeD;
@@ -220,7 +220,7 @@ void cutlass_mla_decode(
   auto in_dtype = q_nope.dtype();
   at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
-  const int page_size = kv_c_and_k_pe_cache.sizes()[1];
+  const int page_size = kv_c_and_k_pe_cache.size(1);
   // NOTE(alcanderian): IsPersistent has bug with manual split_kv.
   // Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
...
@@ -640,9 +640,9 @@ void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch:
   TORCH_CHECK(output.size(0) == num_tokens, "required output.shape[0] == mat_a.shape[0]")
   TORCH_CHECK(output.size(1) == hd_out, "required output.shape[1] == mat_b.shape[1]")
-  TORCH_CHECK(mat_a.strides()[1] == 1);  // Row-major
-  TORCH_CHECK(output.strides()[1] == 1);  // Row-major
-  TORCH_CHECK(mat_b.strides()[0] == 1);  // Column-major
+  TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");  // Row-major
+  TORCH_CHECK(output.stride(1) == 1, "output must be a row major tensor");  // Row-major
+  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");  // Column-major
   auto const data_type = mat_a.scalar_type();
   TORCH_CHECK(
...
@@ -353,7 +353,7 @@ torch::Tensor fp8_blockwise_scaled_mm(
   TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
   TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
   TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
-  TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
+  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
   TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
   TORCH_CHECK(
...
@@ -1080,7 +1080,7 @@ torch::Tensor fp8_scaled_mm(
   TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
   TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
   TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
-  TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
+  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
   TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
   TORCH_CHECK(
...
@@ -672,7 +672,7 @@ torch::Tensor int8_scaled_mm(
   TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
   TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor");
   TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor");
-  TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor");
+  TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor");
   TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied");
   TORCH_CHECK(mat_a.size(1) % 16 == 0, "mat_a.size(1) must be multiple of 16 for memory alignment");
   TORCH_CHECK(mat_b.size(0) % 16 == 0, "mat_b.size(0) must be multiple of 16 for memory alignment");
...
@@ -273,20 +273,20 @@ void cutlass_scaled_fp4_mm_sm100a(
   TORCH_CHECK(A.dim() == 2, "a must be a matrix");
   TORCH_CHECK(B.dim() == 2, "b must be a matrix");
   TORCH_CHECK(
-      A.sizes()[1] == B.sizes()[1],
+      A.size(1) == B.size(1),
       "a and b shapes cannot be multiplied (",
-      A.sizes()[0],
+      A.size(0),
       "x",
-      A.sizes()[1],
+      A.size(1),
       " and ",
-      B.sizes()[0],
+      B.size(0),
       "x",
-      B.sizes()[1],
+      B.size(1),
       ")");
-  auto const m = A.sizes()[0];
-  auto const n = B.sizes()[0];
-  auto const k = A.sizes()[1] * 2;
+  auto const m = A.size(0);
+  auto const n = B.size(0);
+  auto const k = A.size(1) * 2;
   constexpr int alignment = 32;
   TORCH_CHECK(
@@ -294,9 +294,9 @@ void cutlass_scaled_fp4_mm_sm100a(
       "Expected k to be divisible by ",
       alignment,
       ", but got a shape: (",
-      A.sizes()[0],
+      A.size(0),
       "x",
-      A.sizes()[1],
+      A.size(1),
       "), k: ",
       k,
       ".");
@@ -305,9 +305,9 @@ void cutlass_scaled_fp4_mm_sm100a(
      "Expected n to be divisible by ",
       alignment,
       ", but got b shape: (",
-      B.sizes()[0],
+      B.size(0),
       "x",
-      B.sizes()[1],
+      B.size(1),
       ").");
   auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
@@ -320,37 +320,37 @@ void cutlass_scaled_fp4_mm_sm100a(
   TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
   TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
   TORCH_CHECK(
-      A_sf.sizes()[1] == B_sf.sizes()[1],
+      A_sf.size(1) == B_sf.size(1),
       "scale_a and scale_b shapes cannot be multiplied (",
-      A_sf.sizes()[0],
+      A_sf.size(0),
       "x",
-      A_sf.sizes()[1],
+      A_sf.size(1),
       " and ",
-      B_sf.sizes()[0],
+      B_sf.size(0),
       "x",
-      B_sf.sizes()[1],
+      B_sf.size(1),
       ")");
   TORCH_CHECK(
-      A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
+      A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
       "scale_a must be padded and swizzled to a shape (",
       rounded_m,
       "x",
       rounded_k,
       "), but got a shape (",
-      A_sf.sizes()[0],
+      A_sf.size(0),
       "x",
-      A_sf.sizes()[1],
+      A_sf.size(1),
       ")");
   TORCH_CHECK(
-      B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
+      B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
       "scale_b must be padded and swizzled to a shape (",
       rounded_n,
       "x",
       rounded_k,
       "), but got a shape (",
-      B_sf.sizes()[0],
+      B_sf.size(0),
       "x",
-      B_sf.sizes()[1],
+      B_sf.size(1),
       ")");
   auto out_dtype = D.dtype();
...