"...api/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6b06c30a65f3ae90cc2bc2cf3359cff741b4e139"
Unverified commit 58f9060e authored by Ke Bao, committed by GitHub

Update int8 gemm config (#2774)

parent bdc1acf6
@@ -88,10 +88,11 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons
   auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());
 
   auto can_implement = gemm_op.can_implement(args);
-  TORCH_CHECK(can_implement == cutlass::Status::kSuccess)
+  TORCH_CHECK(can_implement == cutlass::Status::kSuccess,
+              "gemm cannot implement, error: ", cutlassGetStatusString(can_implement));
 
   auto status = gemm_op(args, workspace.data_ptr(), stream);
-  TORCH_CHECK(status == cutlass::Status::kSuccess)
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
 }
 
 template <typename ElementOutput, typename ArchTag, typename InstructionShape>
@@ -144,7 +145,17 @@ void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const t
                              cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
                                                                                          scales_b, bias);
     }
-  } else if (m <= 64 || (m <= 128 && n < 8192)) {
+  } else if (m <= 64) {
+    if (n <= 4096) {
+      cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 64, 128>,
+                             cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
+                                                                                         scales_b, bias);
+    } else {
+      cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>,
+                             cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
+                                                                                         scales_b, bias);
+    }
+  } else if (m <= 128 && n < 8192) {
     cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>,
                            cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
                                                                                        scales_b, bias);
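In prose: the old ladder used one tile configuration for every m <= 64 (and for m <= 128 with n < 8192); after this change, the m <= 64 case additionally branches on n, using the narrower 64x64x128 threadblock for n <= 4096 and the wider 64x128x128 threadblock otherwise. Below is a minimal, illustrative Python sketch of the post-change selection; the function name and tuple layout are not part of the source, and the tuples stand for (ThreadblockShape, WarpShape, NumStages).

def select_sm80_int8_tile(m, n):
    # Illustrative only: mirrors the (m, n) thresholds added in this hunk.
    # Branches for smaller m, handled earlier in sm80_dispatch_shape, are omitted.
    if m <= 64:
        if n <= 4096:
            return ((64, 64, 128), (32, 64, 64), 5)   # ThreadblockShape, WarpShape, stages
        return ((64, 128, 128), (64, 64, 64), 5)
    if m <= 128 and n < 8192:
        return ((64, 128, 128), (64, 64, 64), 5)
    return None  # larger shapes fall through to branches outside this hunk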
@@ -37,8 +37,8 @@ class TestInt8Gemm(unittest.TestCase):
         print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
 
     def test_accuracy(self):
-        Ms = [1, 128, 512, 1024, 4096]
-        Ns = [16, 128, 512, 1024, 4096]
+        Ms = [1, 128, 512, 1024, 4096, 8192]
+        Ns = [16, 128, 512, 1024, 4096, 8192, 16384]
         Ks = [512, 1024, 4096, 8192, 16384]
         bias_opts = [True, False]
         out_dtypes = [torch.float16, torch.bfloat16]
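The test change only widens the accuracy sweep; the added M=8192 and N=8192/16384 values presumably cover the new small-m, large-n branch above (e.g. M=1 with N=8192). A self-contained sketch of the size of the expanded grid, assuming only itertools and torch:

import itertools

import torch

Ms = [1, 128, 512, 1024, 4096, 8192]
Ns = [16, 128, 512, 1024, 4096, 8192, 16384]
Ks = [512, 1024, 4096, 8192, 16384]
bias_opts = [True, False]
out_dtypes = [torch.float16, torch.bfloat16]

# Each combination becomes one accuracy check in test_accuracy.
cases = list(itertools.product(Ms, Ns, Ks, bias_opts, out_dtypes))
print(f"{len(cases)} (M, N, K, bias, dtype) combinations")  # 6 * 7 * 5 * 2 * 2 = 840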