"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "94ba75d7f470a412a5d4f3cea3728792f49b38f7"
Commit e8f92b93 authored by yuguo

[DCU] fix batchgemm

parent c37084b9
@@ -731,20 +731,6 @@ void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHan
 #endif
 #ifdef __HIP_PLATFORM_AMD__
-static inline int getIntEnv(const char *name, int defval, int minval)
-{
-  int val = defval;
-  const char* env = std::getenv(name);
-  if (env != nullptr && env[0] != '\0')
-  {
-    val = atoi(env);
-    if (val < minval)
-    {
-      val = minval;
-    }
-  }
-  return val;
-}
 void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                         const NVTETensor *bias, NVTETensor *pre_gelu_out,
......
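
The helper deleted in this hunk follows a common configuration pattern: read an integer from an environment variable, fall back to a default when it is unset or empty, and clamp the result to a minimum (in this file it presumably fed the multi-stream batch-GEMM path, though the call site is not shown in the commit). Below is a minimal, self-contained sketch of that pattern; the variable name NVTE_BGEMM_NUM_STREAMS and the helper get_int_env are illustrative assumptions, not identifiers from the repository.

// Standalone sketch of the pattern used by the removed getIntEnv helper:
// read an integer from the environment, fall back to a default, clamp to a minimum.
// NVTE_BGEMM_NUM_STREAMS is a hypothetical variable name, chosen only for illustration.
#include <cstdlib>   // std::getenv, std::atoi
#include <iostream>

static int get_int_env(const char *name, int defval, int minval) {
  const char *env = std::getenv(name);
  if (env == nullptr || env[0] == '\0') {
    return defval;                        // unset or empty: use the default
  }
  int val = std::atoi(env);               // non-numeric input parses to 0
  return (val < minval) ? minval : val;   // enforce the lower bound
}

int main() {
  // NVTE_BGEMM_NUM_STREAMS=8 ./a.out prints 8; with the variable unset it prints 4.
  int num_streams = get_int_env("NVTE_BGEMM_NUM_STREAMS", /*defval=*/4, /*minval=*/1);
  std::cout << "num_streams = " << num_streams << "\n";
  return 0;
}
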
@@ -5,7 +5,8 @@
 """Module level PyTorch APIs"""
 from .layernorm_linear import LayerNormLinear
 from .linear import Linear
-from .grouped_linear import GroupedLinear, BatchedLinear
+from .grouped_linear import GroupedLinear
+from .batched_linear import BatchedLinear
 from .layernorm_mlp import LayerNormMLP
 from .layernorm import LayerNorm
 from .rmsnorm import RMSNorm
......