"src/git@developer.sourcefind.cn:jerrrrry/infinicore.git" did not exist on "8848a764df5f55e15d75a54447f75110d4c11045"
Unverified commit 01801633, authored by Phuong Nguyen, committed by GitHub
Browse files

[Common] Added Alignment Requirements for CuBLAS heuristics (#845)



* added alignment requirements for CuBLAS heuristics
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* minor rewords
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* added unit test for gemm with unaligned inputs
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* added pytest skip if fp8 is not available
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* changed offset so that it has alignment with 128
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent d705f7ff
......@@ -29,6 +29,10 @@ from transformer_engine.pytorch import (
get_cpu_offload_context,
)
from transformer_engine.common import recipe
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8
from transformer_engine.pytorch.module.base import get_workspace
from test_onnx_export import create_meta
# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
......@@ -924,3 +928,61 @@ def test_model_multiple_cast():
y2 = m(a)
assert y2.dtype == torch.float16
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
    """Smoke-test a non-FP8 GEMM whose operand pointers are deliberately unaligned.

    Builds two overlapping N x N views into one flat buffer, shifted by an odd
    number of elements, so the device pointers handed to the GEMM do not sit on
    the usual alignment boundaries. Only checks that the call runs to
    completion; outputs are discarded.
    """
    # Flat buffer with `offset` spare elements on each end.
    buffer = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)
    # View 1 starts `offset` elements in; view 2 starts `2 * offset` in.
    inp = buffer[offset:-offset].reshape(N, N)
    weight = buffer[2 * offset:].reshape(N, N)
    _, _, _ = gemm(
        A=weight,
        B=inp,
        dtype=datatype,
        workspace=get_workspace())
    # Surface any asynchronous kernel failure inside this test.
    torch.cuda.synchronize()
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
    """Smoke-test an FP8 GEMM whose operand views start at a small element offset.

    Casts one flat buffer to FP8 and carves out two overlapping N x N views
    shifted by `offset` elements, so the GEMM receives pointers off the
    default boundary. Only checks that the call runs to completion.
    """
    offset = 16
    # One flat buffer; the two views below overlap, shifted by `offset`.
    scratchpad = torch.randn(N * N + offset, device="cuda", dtype=datatype)

    fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
    fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
    nb_inp_scales = 1
    nb_weight_scales = N
    scale_factor = 1.
    meta_inp = create_meta(scale_factor, nb_inp_scales)
    meta_weight = create_meta(scale_factor, nb_weight_scales)

    inp_type = tex.DType.kFloat8E4M3
    weights_type = tex.DType.kFloat8E4M3
    outp_type = datatype

    # NOTE(review): the cast pairs meta_weight with the *input* tensor index;
    # all scales are 1.0 here so it should be inconsequential — confirm intent.
    scratchpad_fp8 = cast_to_fp8(
        scratchpad,
        meta_weight,
        fp8_tensor_inp,
        inp_type)
    inp_fp8 = scratchpad_fp8[:-offset].reshape(N, N)
    weight_fp8 = scratchpad_fp8[offset:].reshape(N, N)

    _, _ = fp8_gemm(
        weight_fp8,
        meta_weight.scale_inv,
        fp8_tensor_weight,
        inp_type,
        inp_fp8,
        meta_inp.scale_inv,
        fp8_tensor_inp,
        weights_type,
        outp_type,
        get_workspace(),
        bias=None,
        use_bias=False,
        use_split_accumulator=False)
    # Surface any asynchronous kernel failure inside this test.
    torch.cuda.synchronize()
......@@ -9,6 +9,7 @@
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cstdint>
#include <transformer_engine/transformer_engine.h>
#include "../common.h"
......@@ -34,6 +35,16 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
}
}
uint32_t _getAlignment(uintptr_t address) {
  // Return the largest power-of-two byte alignment (capped at 256) that
  // divides `address`. Always terminates: every address is divisible by 1.
  uint32_t candidate = 256;
  while (address % candidate != 0) {
    candidate /= 2;
  }
  return candidate;
}
} // namespace
namespace transformer_engine {
......@@ -260,6 +271,22 @@ void cublas_gemm(const Tensor *inputA,
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspaceSize, sizeof(workspaceSize)));
const auto A_alignment = _getAlignment(reinterpret_cast<uintptr_t>(A));
const auto B_alignment = _getAlignment(reinterpret_cast<uintptr_t>(B));
const auto C_alignment = _getAlignment(reinterpret_cast<uintptr_t>(C));
const auto D_alignment = _getAlignment(reinterpret_cast<uintptr_t>(D));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
&A_alignment, sizeof(A_alignment)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES,
&B_alignment, sizeof(B_alignment)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES,
&C_alignment, sizeof(C_alignment)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES,
&D_alignment, sizeof(D_alignment)));
const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc,
Ddesc, preference, 1, &heuristicResult,
......@@ -271,7 +298,6 @@ void cublas_gemm(const Tensor *inputA,
if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms");
// D = alpha * (A * B) + beta * C
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle,
operationDesc,
static_cast<const void*>(&one), /* alpha */
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment