Unverified Commit 01801633 authored by Phuong Nguyen, committed by GitHub

[Common] Added Alignment Requirements for CuBLAS heuristics (#845)



* added alignment requirements for CuBLAS heuristics
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* minor rewords
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* added unit test for gemm with unaligned inputs
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* added pytest skip if fp8 is not available
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* changed offset so that it has alignment with 128
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent d705f7ff
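For context: cublasLtMatmulAlgoGetHeuristic assumes, by default, that all matrix pointers are 256-byte aligned (the CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_{A,B,C,D}_BYTES preference attributes default to 256, as I read the cuBLASLt docs). When a GEMM operand is a slice of a larger tensor, its data pointer can be far less aligned, so the heuristic may pick algorithms that are invalid for the actual pointers. This commit measures the real alignment of A, B, C, and D and reports it to the preference object before the heuristic query. A minimal Python sketch of the idea (ptr_alignment is a hypothetical helper mirroring the _getAlignment function added below):

import torch

def ptr_alignment(address: int, max_align: int = 256) -> int:
    # Largest power of two, capped at max_align bytes, dividing the address.
    align = max_align
    while address % align != 0:
        align //= 2
    return align

# A slice at an element offset shifts the data pointer by
# offset * element_size bytes, losing the allocator's alignment.
buf = torch.randn(32 * 32 + 2, device="cuda", dtype=torch.float16)
base = buf[:-2]      # starts at the allocation base, typically 256-byte aligned
shifted = buf[1:-1]  # one fp16 element in: the pointer moves by 2 bytes

print(ptr_alignment(base.data_ptr()))     # usually 256
print(ptr_alignment(shifted.data_ptr()))  # 2
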
@@ -29,6 +29,10 @@ from transformer_engine.pytorch import (
    get_cpu_offload_context,
)
from transformer_engine.common import recipe
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.module.base import get_workspace
from test_onnx_export import create_meta

# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -924,3 +928,61 @@ def test_model_multiple_cast():
    y2 = m(a)
    assert y2.dtype == torch.float16

@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
scratchpad = torch.randn(N*N + 2*offset, device="cuda", dtype=datatype)
inp = torch.reshape(scratchpad[offset:-offset], (N, N))
weight = torch.reshape(scratchpad[offset*2:], (N, N))
_, _, _ = gemm(
A=weight,
B=inp,
dtype=datatype,
workspace=get_workspace())
torch.cuda.synchronize()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    offset = 16
    scratchpad = torch.randn(N*N + offset, device="cuda", dtype=datatype)
    fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
    fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
    nb_inp_scales, nb_weight_scales = 1, N
    scale_factor = 1.
    meta_inp = create_meta(scale_factor, nb_inp_scales)
    meta_weight = create_meta(scale_factor, nb_weight_scales)
    inp_type = tex.DType.kFloat8E4M3
    weights_type = tex.DType.kFloat8E4M3
    outp_type = datatype
    scratchpad_fp8 = cast_to_fp8(
        scratchpad,
        meta_weight,
        fp8_tensor_inp,
        inp_type)
    # A 16-element (16-byte) offset into the FP8 buffer keeps the pointers
    # 16-byte (128-bit) aligned while breaking the 256-byte default.
    inp_fp8 = torch.reshape(scratchpad_fp8[:-offset], (N, N))
    weight_fp8 = torch.reshape(scratchpad_fp8[offset:], (N, N))
    _, _ = fp8_gemm(
        weight_fp8,
        meta_weight.scale_inv,
        fp8_tensor_weight,
        weights_type,
        inp_fp8,
        meta_inp.scale_inv,
        fp8_tensor_inp,
        inp_type,
        outp_type,
        get_workspace(),
        bias=None,
        use_bias=False,
        use_split_accumulator=False)
    torch.cuda.synchronize()
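A note on offset = 16 in the FP8 test: FP8 elements are one byte, so the slice shifts the pointers by exactly 16 bytes. That preserves the 16-byte (128-bit) alignment that, as I understand cuBLASLt's requirements, FP8 operands must have, while still falling below the 256-byte alignment the heuristic would otherwise assume. Reusing the hypothetical ptr_alignment sketch from above:

base = 256 * 0x1000              # a pretend 256-byte-aligned device address
print(ptr_alignment(base + 16))  # 16: legal for FP8, but not 256-byte aligned
print(ptr_alignment(base + 1))   # 1: too little for FP8, hence offset = 16
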
@@ -9,6 +9,7 @@
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cstdint>
#include <transformer_engine/transformer_engine.h>
#include "../common.h"

@@ -34,6 +35,16 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
  }
}

uint32_t _getAlignment(uintptr_t address) {
  // Alignment is in bytes. Return the largest power of two, capped at
  // 256 bytes, that divides the address; the loop always terminates
  // because every address is divisible by 1.
  uint32_t alignment = 256;
  for (;; alignment /= 2) {
    if (address % alignment == 0) {
      return alignment;
    }
  }
}
}  // namespace

namespace transformer_engine {
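The _getAlignment helper halves its candidate, starting from 256, until the candidate divides the address. For illustration only (this is not part of the commit), an equivalent branch-light form isolates the lowest set bit of the address; a minimal Python sketch:

def ptr_alignment_fast(address: int, max_align: int = 256) -> int:
    # address & -address isolates the lowest set bit, i.e. the largest
    # power of two dividing the address; cap it at max_align. Address 0
    # is divisible by anything, so report max_align.
    return min(max_align, address & -address) if address else max_align

assert ptr_alignment_fast(0x1000) == 256  # 4096-byte aligned, capped at 256
assert ptr_alignment_fast(0x1010) == 16
assert ptr_alignment_fast(0x1001) == 1
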
@@ -260,6 +271,22 @@ void cublas_gemm(const Tensor *inputA,
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
      &workspaceSize, sizeof(workspaceSize)));
  // Report the actual pointer alignments to the heuristic, overriding
  // the 256-byte default assumption.
  const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(A));
  const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(B));
  const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
  const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
      &A_alignment, sizeof(A_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
      &B_alignment, sizeof(B_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
      &C_alignment, sizeof(C_alignment)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
      &D_alignment, sizeof(D_alignment)));
  const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc,
                                                     Ddesc, preference, 1, &heuristicResult,

@@ -271,7 +298,6 @@ void cublas_gemm(const Tensor *inputA,
  if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms");

  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle,
                                   operationDesc,
                                   static_cast<const void*>(&one),  /* alpha */
...