Commit b62c420f authored by Max Rietmann

Move permute out of bwd kernel & cache qy in shared memory

Putting qy in shared memory is a little faster.

Changing the internal memory layout means we can leave the kernel code in its
standard shape and only change the layout external to the kernel.
parent 5f051c97
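To illustrate the layout idea (a minimal PyTorch sketch, not part of this commit, based on the channels_last format that the updated test below exercises): the tensors keep their logical [batch, channel, h, w] shape and only their strides change, so the kernel's stride-aware packed accessors can keep indexing [b][c][h][w] while warp threads stepping through the channel dimension touch contiguous memory.

import torch

# Minimal sketch: the logical NCHW shape is unchanged; only the physical
# layout (the strides) differs, and strides are what the kernel's packed
# tensor accessors consume.
x = torch.randn(1, 256, 721, 1440)                       # logical [B, C, H, W]
x_cl = x.contiguous(memory_format=torch.channels_last)   # physical NHWC

print(x.shape == x_cl.shape)   # True: same logical shape
print(x.stride())              # (C*H*W, H*W, W, 1) -> channel stride is H*W
print(x_cl.stride())           # (H*W*C, 1, W*C, C) -> channel stride is 1

With a channel stride of 1, the per-channel loops in the backward kernel (chan = tidx; chan < num_channels; chan += WARP_SIZE) read coalesced memory, and any permute/contiguous calls can happen, or be skipped, outside the kernel.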
@@ -212,5 +212,98 @@ class TestNeighborhoodAttentionS2(unittest.TestCase):
            self.assertTrue(torch.allclose(grad, grad_ref, atol=atol, rtol=rtol), f"Parameter gradient mismatch")

    @parameterized.expand(
        [
            # self attention
            [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
        ]
    )
    def test_perf(self, batch_size, channels, heads, in_shape, out_shape, grid_in, grid_out, atol, rtol):

        # this test only makes sense when CUDA version is available
        if torch.cuda.is_available():
            if not _cuda_extension_available:
                print("WARNING: Problem loading CUDA attention module")
                return

            # extract some parameters
            nlat_in, nlon_in = in_shape
            nlat_out, nlon_out = out_shape

            # TODO: this test seems hardcoded for GPU. Is this necessary?
            k_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cuda:0")
            k_gpu.requires_grad = False
            v_gpu = torch.randn(batch_size, channels, nlat_in, nlon_in, dtype=torch.float32, device="cuda:0")
            v_gpu.requires_grad = False
            q_gpu = torch.randn(batch_size, channels, nlat_out, nlon_out, dtype=torch.float32, device="cuda:0")
            q_gpu.requires_grad = False

            # set up layers
            time_layer_setup_start = torch.cuda.Event(enable_timing=True)
            time_layer_setup_end = torch.cuda.Event(enable_timing=True)
            time_layer_setup_start.record()
            att_gpu = NeighborhoodAttentionS2(in_channels=channels, num_heads=heads,
                                              in_shape=in_shape, out_shape=out_shape,
                                              grid_in=grid_in, grid_out=grid_out, bias=True).to("cuda:0")
            time_layer_setup_end.record()
            torch.cuda.synchronize()
            print(f"Layer setup: {time_layer_setup_start.elapsed_time(time_layer_setup_end)} ms")

            # random weights
            with torch.no_grad():
                att_gpu.q_weights.normal_()
                att_gpu.k_weights.normal_()
                att_gpu.v_weights.normal_()
                att_gpu.q_bias.normal_()
                att_gpu.k_bias.normal_()
                att_gpu.v_bias.normal_()

            # time forward pass
            for i in range(2):
                # warmup
                out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)

            time_forward_start = torch.cuda.Event(enable_timing=True)
            time_forward_end = torch.cuda.Event(enable_timing=True)
            time_forward_start.record()
            out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
            time_forward_end.record()
            torch.cuda.synchronize()
            print(f"Forward execution: {time_forward_start.elapsed_time(time_forward_end)} ms")
            # sync weights (note: these copy_ calls are self-copies and thus
            # effectively a no-op in this GPU-only timing test):
            with torch.no_grad():
                att_gpu.q_weights.copy_(att_gpu.q_weights)
                att_gpu.k_weights.copy_(att_gpu.k_weights)
                att_gpu.v_weights.copy_(att_gpu.v_weights)
                att_gpu.q_bias.copy_(att_gpu.q_bias)
                att_gpu.k_bias.copy_(att_gpu.k_bias)
                att_gpu.v_bias.copy_(att_gpu.v_bias)
            q_gpu = q_gpu.detach().clone().to(self.device, memory_format=torch.channels_last)
            q_gpu.requires_grad = True
            k_gpu = k_gpu.detach().clone().to(self.device, memory_format=torch.channels_last)
            k_gpu.requires_grad = True
            v_gpu = v_gpu.detach().clone().to(self.device, memory_format=torch.channels_last)
            v_gpu.requires_grad = True

            out_gpu = att_gpu(q_gpu, k_gpu, v_gpu)
            out_grad = torch.randn(out_gpu.shape, dtype=torch.float32, device="cuda:0").to(memory_format=torch.channels_last)

            time_backward_start = torch.cuda.Event(enable_timing=True)
            time_backward_end = torch.cuda.Event(enable_timing=True)

            print("q_gpu_stride =", q_gpu.stride())
            for i in range(2):
                # warmup
                out_gpu.backward(out_grad, retain_graph=True)
            print("out_grad_stride =", out_grad.stride())

            time_backward_start.record()
            out_gpu.backward(out_grad)
            time_backward_end.record()
            torch.cuda.synchronize()
            print(f"Backward execution: {time_backward_start.elapsed_time(time_backward_end)} ms")

if __name__ == "__main__":
    unittest.main()
@@ -29,6 +29,7 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "attention.cuh"
#include "c10/core/MemoryFormat.h"
#include <ATen/core/TensorAccessor.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
@@ -36,6 +37,7 @@
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDAUtils.h>
#include <ctime>
#include <cub/cub.cuh>
#include <limits>
@@ -61,6 +63,26 @@
}}
#endif
#include <iostream>
#include <chrono>
#include <string>
class ScopeTimer {
  public:
    explicit ScopeTimer(const std::string& label = "")
        : label_(label), start_(std::chrono::high_resolution_clock::now()) {}

    ~ScopeTimer() {
        auto end = std::chrono::high_resolution_clock::now();
        auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start_);
        std::cout << label_ << " elapsed time: " << elapsed.count() << " ms" << std::endl;
    }

  private:
    std::string label_;
    std::chrono::high_resolution_clock::time_point start_;
};
__device__ static float atomicMax(float* address, float val)
{
int* address_as_i = (int*) address;
@@ -738,10 +760,11 @@ __launch_bounds__(BDIM_X)
const torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits> quad_weights) {
extern __shared__ float sh[];
float* sh_alpha_k = sh + threadIdx.y * num_channels * 4;
float* sh_alpha_k = sh + threadIdx.y * num_channels * 5;
float* sh_alpha_vw = sh_alpha_k + num_channels;
float* sh_alpha_kvw = sh_alpha_vw + num_channels;
float* sh_dy = sh_alpha_kvw + num_channels;
float *sh_dy = sh_alpha_kvw + num_channels;
float* sh_qy = sh_dy + num_channels;
// (optionally, could use more shared memory for other intermediates)
const uint64_t batchId = blockIdx.y;
@@ -756,7 +779,8 @@ __launch_bounds__(BDIM_X)
sh_alpha_k[chan] = 0.0f;
sh_alpha_vw[chan] = 0.0f;
sh_alpha_kvw[chan] = 0.0f;
sh_dy[chan] = dy[batchId][ho][wo][chan];
sh_dy[chan] = dy[batchId][chan][ho][wo];
sh_qy[chan] = qy[batchId][chan][ho][wo];
}
float alpha_sum = 0.0f;
float qdotk_max = -FLT_MAX;
@@ -775,7 +799,7 @@ __launch_bounds__(BDIM_X)
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][ho][wo][chan] * kx[batchId][hi][wip][chan];
qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
}
qdotk = __warp_sum_cub(qdotk);
qdotk_max = max(qdotk_max, qdotk);
@@ -789,8 +813,8 @@ __launch_bounds__(BDIM_X)
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f, gdotv = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][ho][wo][chan] * kx[batchId][hi][wip][chan];
gdotv += sh_dy[chan] * vx[batchId][hi][wip][chan];
qdotk += sh_qy[chan] * kx[batchId][chan][hi][wip];
gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
}
qdotk = __warp_sum_cub(qdotk);
gdotv = __warp_sum_cub(gdotv);
@@ -798,7 +822,7 @@ __launch_bounds__(BDIM_X)
alpha_sum += alpha_inz;
integral += alpha_inz * gdotv;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
float kxval = kx[batchId][hi][wip][chan];
float kxval = kx[batchId][chan][hi][wip];
sh_alpha_k[chan] += alpha_inz * kxval;
sh_alpha_vw[chan] += alpha_inz * gdotv;
sh_alpha_kvw[chan] += alpha_inz * kxval * gdotv;
@@ -809,7 +833,7 @@ __launch_bounds__(BDIM_X)
// Write dydq
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
dydq[batchId][ho][wo][chan] = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
dydq[batchId][chan][ho][wo] = (sh_alpha_kvw[chan] * alpha_sum - sh_alpha_vw[chan] * sh_alpha_k[chan]) / (alpha_sum * alpha_sum);
}
// Third pass: accumulate gradients for k and v
@@ -820,17 +844,17 @@ __launch_bounds__(BDIM_X)
const int wip = (wi + wo) - ((wi + wo) / nlon_in) * nlon_in;
float qdotk = 0.0f, gdotv = 0.0f;
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
qdotk += qy[batchId][ho][wo][chan] * kx[batchId][hi][wip][chan];
gdotv += sh_dy[chan] * vx[batchId][hi][wip][chan];
qdotk += qy[batchId][chan][ho][wo] * kx[batchId][chan][hi][wip];
gdotv += sh_dy[chan] * vx[batchId][chan][hi][wip];
}
qdotk = __warp_sum_cub(qdotk);
gdotv = __warp_sum_cub(gdotv);
float alpha_inz = expf(qdotk - qdotk_max) * quad_weights[hi];
for (int chan = tidx; chan < num_channels; chan += WARP_SIZE) {
float qyval = qy[batchId][ho][wo][chan];
float qyval = qy[batchId][chan][ho][wo];
float dyval = sh_dy[chan];
atomicAdd(&dydk[batchId][hi][wip][chan], qyval * (alpha_inz / alpha_sum) * (gdotv - integral));
atomicAdd(&dydv[batchId][hi][wip][chan], (alpha_inz / alpha_sum) * dyval);
atomicAdd(&dydk[batchId][chan][hi][wip], qyval * (alpha_inz / alpha_sum) * (gdotv - integral));
atomicAdd(&dydv[batchId][chan][hi][wip], (alpha_inz / alpha_sum) * dyval);
}
}
}
@@ -1038,23 +1062,42 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
return std::make_tuple(dydk, dydv, dydq);
} else if (version == HOWO_WARP_VERSION) {
ScopeTimer timer("Full s2_attention_bwd_dkvq_kernel_mbT");
// Time this function via C++
time_t start_time, end_time;
start_time = clock();
// Transpose to [batch, ho, wo, channel]
auto kxP = kx.permute({0,2,3,1}).contiguous();
auto vxP = vx.permute({0,2,3,1}).contiguous();
auto qyP = qy.permute({0,2,3,1}).contiguous();
auto dyP = dy.permute({0,2,3,1}).contiguous();
auto dydkP = torch::zeros_like(qyP);
auto dydvP = torch::zeros_like(qyP);
auto dydqP = torch::zeros_like(qyP);
// nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT permute inputs");
// auto* permute_timer = new ScopeTimer("permute inputs");
// auto kxP = kx.permute({0,2,3,1}).contiguous().permute({0,3,1,2});
// auto vxP = vx.permute({0,2,3,1}).contiguous().permute({0,3,1,2});
// auto qyP = qy.permute({0,2,3,1}).contiguous().permute({0,3,1,2});
// auto dyP = dy.permute({0, 2, 3, 1}).contiguous().permute({0, 3, 1, 2});
// cudaDeviceSynchronize();
// delete permute_timer;
// nvtxRangePop();
nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output allocation & zero");
auto dydkP = torch::zeros_like(qy);
auto dydvP = torch::zeros_like(qy);
auto dydqP = torch::zeros_like(qy);
// print strides of dydkP, dydvP, dydqP
printf("dydkP strides: ");
for (auto& stride_i : dydkP.strides()) {
    printf("%ld ", stride_i);
}
printf("\n");
cudaDeviceSynchronize();
nvtxRangePop();
size_t uo_num_channels = kx.size(1);
const int batch_size = kx.size(0);
dim3 block(WARP_SIZE, THREADS/WARP_SIZE);
dim3 grid(DIV_UP(nlat_out*nlon_out, block.y), batch_size);
size_t shared_size = sizeof(float) * uo_num_channels * 4 * block.y; // 4 arrays per warp
size_t shared_size = sizeof(float) * uo_num_channels * 5 * block.y; // 5 arrays per warp (alpha_k, alpha_vw, alpha_kvw, dy, qy)
cudaEvent_t start, stop;
float milliseconds = 0;
@@ -1065,10 +1108,10 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
s2_attention_bwd_dkvq_kernel_mbT<THREADS><<<
grid, block, shared_size, stream>>>(
uo_num_channels, nlon_in, nlat_out, nlon_out,
kxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vxP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dyP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
kx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
vx.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
qy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dy.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydkP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydvP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
dydqP.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
@@ -1082,17 +1125,22 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> s2_attention_bwd_dkvq_cuda(at::Tens
// [1, 256, 1, (721, 1440), (721, 1440), "equiangular", "equiangular", 1e-5, 1e-5],
// s2_attention_bwd_kernel_mbT execution time: 63.280128 ms
// printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
printf("s2_attention_bwd_kernel_mbT execution time: %f ms\n", milliseconds);
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
C10_CUDA_KERNEL_LAUNCH_CHECK();
// Permute outputs back to [batch, channel, ho, wo]
auto dydk = dydkP.permute({0,3,1,2}).contiguous();
auto dydv = dydvP.permute({0,3,1,2}).contiguous();
auto dydq = dydqP.permute({0,3,1,2}).contiguous();
return std::make_tuple(dydk, dydv, dydq);
// nvtxRangePush("s2_attention_bwd_dkvq_kernel_mbT output permutation");
// auto* permute_output_timer = new ScopeTimer("permute outputs");
// auto dydk = dydkP.permute({0,3,1,2}).contiguous().permute({0,3,1,2});
// auto dydv = dydvP.permute({0,3,1,2}).contiguous();
// auto dydq = dydqP.permute({0, 3, 1, 2}).contiguous();
// cudaDeviceSynchronize();
// delete permute_output_timer;
// nvtxRangePop();
return std::make_tuple(dydkP, dydvP, dydqP);
} else {
throw std::runtime_error("Invalid kernel version specified");
}