Commit cb8b7a88 authored by hubertlu-tw's avatar hubertlu-tw
Browse files

Merge remote-tracking branch 'origin/master' into IFU-master-2022-07-29

parents 51783cc7 c97ebfab
......@@ -886,7 +886,7 @@ void HostApplyLayerNorm(
)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;
const int warp_size = at::cuda::warp_size();
dim3 threads(warp_size ,4, 1); // MI100 wavefront/warp = 64
#ifdef __HIP_PLATFORM_HCC__
// Optimization for ROCm MI100
......@@ -915,7 +915,7 @@ void HostApplyRMSNorm(
const V* gamma)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;
const int warp_size = at::cuda::warp_size();
const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
dim3 threads(warp_size,4,1);
......@@ -1009,7 +1009,7 @@ void HostLayerNormGradient(
)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;
const int warp_size = at::cuda::warp_size();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
......@@ -1092,7 +1092,7 @@ void HostRMSNormGradient(
V* grad_gamma)
{
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int warp_size = at::cuda::getCurrentDeviceProperties()->warpSize;
const int warp_size = at::cuda::warp_size();
if (gamma != NULL) {
const int part_size = warp_size;
const dim3 threads2(warp_size,4,1);
......
......@@ -262,6 +262,7 @@ class TestAutocastFusedLayerNorm(unittest.TestCase):
with self.subTest(f"{dtype}-{elementwise_affine}"):
self._run_test(dtype, elementwise_affine)
@unittest.skip("Skipped on ROCm5.2 due to the failure of reproducing the issue locally. (Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!) Please refer to https://github.com/ROCmSoftwarePlatform/apex/pull/78")
class TestAutocastFusedRMSNorm(unittest.TestCase):
bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4)
bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment