"vscode:/vscode.git/clone" did not exist on "d5f60b119c50d10376ed5213963cf6d96a8e3fc9"
Commit 8e75ab95 authored by Shenggan

add inject_openfold

parent 90019096
......@@ -74,6 +74,8 @@ def get_data_parallel_group():
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
if not dap_is_initialized():
    return 1
global _TENSOR_MODEL_PARALLEL_WORLD_SIZE
if _TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
......@@ -82,6 +84,8 @@ def get_tensor_model_parallel_world_size():
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
if not dap_is_initialized():
    return 0
global _TENSOR_MODEL_PARALLEL_RANK
if _TENSOR_MODEL_PARALLEL_RANK is not None:
......
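The guards added above make the parallel-state queries safe to call before DAP is initialized: an uninitialized process behaves like a single-rank group. A minimal sketch of the same pattern with stand-in names (`_initialized`, `get_world_size`, `get_rank` are illustrative, not the fastfold API):

_initialized = False
_world_size = None
_rank = None

def get_world_size():
    # Fall back to a single-rank view when the parallel state was never set up.
    if not _initialized:
        return 1
    return _world_size

def get_rank():
    if not _initialized:
        return 0
    return _rank

# Single-process code can call these without any distributed setup:
assert get_world_size() == 1 and get_rank() == 0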
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cassert>
#include <vector>
......@@ -74,6 +75,8 @@ std::vector<at::Tensor> layer_norm_affine(at::Tensor input, at::IntArrayRef norm
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
at::Tensor mean = at::empty({n1}, input.options().dtype(at::ScalarType::Float));
at::Tensor invvar = at::empty_like(mean);
......@@ -104,6 +107,8 @@ std::vector<at::Tensor> layer_norm_gradient_affine(at::Tensor dout, at::Tensor m
int n1, n2;
check_args(input, normalized_shape, gamma, beta, n1, n2);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
at::Tensor grad_input = at::empty_like(input);
at::Tensor grad_gamma = at::empty_like(gamma);
at::Tensor grad_beta = at::empty_like(beta);
......
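The OptionalCUDAGuard lines added above make the bindings switch to the input tensor's device before allocating outputs and launching kernels, which matters when the tensors live on a GPU other than the process's current device. A rough Python-side analogy of the same idea, using torch.cuda.device and the built-in layer norm purely for illustration:

import torch
import torch.nn.functional as F

def layer_norm_on_own_device(x, normalized_shape, weight, bias, eps=1e-5):
    # Rough analogue of OptionalCUDAGuard(device_of(input)): run the op with the
    # tensor's own device as the current one.
    if x.is_cuda:
        with torch.cuda.device(x.device):
            return F.layer_norm(x, normalized_shape, weight, bias, eps)
    return F.layer_norm(x, normalized_shape, weight, bias, eps)

x = torch.randn(4, 256)
out = layer_norm_on_own_device(x, (256,), torch.ones(256), torch.zeros(256))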
......@@ -64,15 +64,27 @@ __global__ void fastfold_layernorm_fp32(float* input, float* output, float* gamm
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
int lane_id = threadidx_y;
if (row_offset < rows) {
float buf[32];
float thread_mean;
float thread_m2;
float thread_count;
float thread_mean = 0.f;
float thread_m2 = 0.f;
float thread_count = 0.f;
float warp_mean;
float warp_m2;
......@@ -81,13 +93,13 @@ __global__ void fastfold_layernorm_fp32(float* input, float* output, float* gamm
float* row_input = input + row_offset * cols;
float* row_output = output + row_offset * cols;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
buf[i] = row_input[lane_id * cols_per_thread + i];
}
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
WelfordOnline(buf[i], &thread_mean, &thread_m2, &thread_count);
}
......@@ -102,16 +114,17 @@ __global__ void fastfold_layernorm_fp32(float* input, float* output, float* gamm
invvar[row_offset] = row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
buf[i] = (buf[i] - row_mean) * row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
buf[i] * gamma[lane_id * cols_per_thread + i] + beta[lane_id * cols_per_thread + i];
}
}
}
__global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* output,
......@@ -120,15 +133,27 @@ __global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* outp
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
int lane_id = threadidx_y;
if (row_offset < rows) {
float buf[32];
float thread_mean;
float thread_m2;
float thread_count;
float thread_mean = 0.f;
float thread_m2 = 0.f;
float thread_count = 0.f;
float warp_mean;
float warp_m2;
......@@ -137,13 +162,13 @@ __global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* outp
at::BFloat16* row_input = input + row_offset * cols;
at::BFloat16* row_output = output + row_offset * cols;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]);
}
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
WelfordOnline(buf[i], &thread_mean, &thread_m2, &thread_count);
}
......@@ -158,23 +183,24 @@ __global__ void fastfold_layernorm_bfp16(at::BFloat16* input, at::BFloat16* outp
invvar[row_offset] = row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
buf[i] = (buf[i] - row_mean) * row_inv_var;
}
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>(buf[i]) * gamma[lane_id * cols_per_thread + i] +
beta[lane_id * cols_per_thread + i];
}
}
}
void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar, at::Tensor* input,
int rows, int cols, at::IntArrayRef normalized_shape, at::Tensor* gamma,
at::Tensor* beta, double epsilon) {
int grid = rows / 4;
int grid = (rows + 3) / 4;
dim3 block(128);
if (output->dtype() == torch::kFloat32) {
......
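The recurring change in this kernel file replaces the exact divisions (cols / 32, rows / 4) with ceiling divisions plus a per-lane cols_this_thread count, so shapes that are not multiples of 32 columns or 4 rows are still fully covered (the row_offset < rows guard skips the over-provisioned block). A small host-side sketch of the same arithmetic, only to sanity-check the partition; it assumes cols <= 32 * 32 so the fixed per-lane buf[32] buffer suffices:

def lane_partition(cols, lanes=32):
    # Mirror of the per-lane column split used in the kernels above.
    cols_per_thread = (cols + lanes - 1) // lanes      # ceiling division
    last_y = cols // cols_per_thread                   # first lane with a short count
    counts = []
    for lane in range(lanes):
        if lane < last_y:
            counts.append(cols_per_thread)
        elif lane == last_y:
            counts.append(cols - cols_per_thread * last_y)
        else:
            counts.append(0)
    return cols_per_thread, counts

for cols in (32, 33, 48, 100, 127, 128):
    cols_per_thread, counts = lane_partition(cols)
    assert sum(counts) == cols            # every column is owned by exactly one lane
    assert cols_per_thread <= 32          # fits the fixed buf[32] per-lane buffer

rows = 10
grid = (rows + 3) // 4                    # same ceiling trick: 4 rows per 128-thread block
assert grid * 4 >= rows                   # extra rows are skipped by `row_offset < rows`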
#include <math_constants.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <iostream>
......@@ -33,41 +34,53 @@ __global__ void fastfold_softmax_fp32(float *input, float *output, int rows, int
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
float *row_input = input + row_offset * cols;
float *row_output = output + row_offset * cols;
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
buf[i] = row_input[lane_id * cols_per_thread + i];
}
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
}
}
}
__global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output, int rows,
......@@ -75,42 +88,55 @@ __global__ void fastfold_softmax_bfp16(at::BFloat16 *input, at::BFloat16 *output
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
at::BFloat16 *row_input = input + row_offset * cols;
at::BFloat16 *row_output = output + row_offset * cols;
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
buf[i] = static_cast<float>(row_input[lane_id * cols_per_thread + i]);
}
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
}
}
}
__global__ void fastfold_softmax_grad_fp32(float *d_output, float *output, float *d_input, int rows,
......@@ -118,37 +144,49 @@ __global__ void fastfold_softmax_grad_fp32(float *d_output, float *output, float
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float y_buf[32];
float dy_buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
float *row_d_output = d_output + row_offset * cols;
float *row_output = output + row_offset * cols;
float *row_d_input = d_input + row_offset * cols;
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
y_buf[i] = row_output[lane_id * cols_per_thread + i];
dy_buf[i] = row_d_output[lane_id * cols_per_thread + i];
}
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_sum += y_buf[i] * dy_buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
row_d_input[lane_id * cols_per_thread + i] = (dy_buf[i] - warp_sum) * y_buf[i];
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_d_input[lane_id * cols_per_thread + i] = (dy_buf[i] - warp_sum) * y_buf[i];
}
}
}
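The gradient kernel above applies the standard softmax backward rule, dx_i = y_i * (dy_i - sum_j y_j * dy_j), one row per 32-lane group (note the write index uses cols_per_thread, matching the reads). A quick torch check of that identity against autograd, separate from the extension:

import torch

x = torch.randn(4, 37, dtype=torch.float64, requires_grad=True)
y = torch.softmax(x, dim=-1)
dy = torch.randn_like(y)

# Rule used by the kernel: dx = (dy - sum(y * dy)) * y, row by row.
dx_manual = (dy - (y * dy).sum(dim=-1, keepdim=True)) * y
dx_autograd, = torch.autograd.grad(y, x, grad_outputs=dy)
assert torch.allclose(dx_manual, dx_autograd)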
......@@ -157,46 +195,60 @@ __global__ void fastfold_softmax_grad_bfp16(at::BFloat16 *d_output, at::BFloat16
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float y_buf[32];
float dy_buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
at::BFloat16 *row_d_output = d_output + row_offset * cols;
at::BFloat16 *row_output = output + row_offset * cols;
at::BFloat16 *row_d_input = d_input + row_offset * cols;
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
dy_buf[i] = static_cast<float>(row_d_output[lane_id * cols_per_thread + i]);
}
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_sum += y_buf[i] * dy_buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_d_input[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>((dy_buf[i] - warp_sum) * y_buf[i]);
}
}
}
at::Tensor softmax(at::Tensor input, int rows, int cols) {
CHECK_INPUT(input);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
at::Tensor output = at::empty_like(input);
int grid = rows / 4;
int grid = (rows + 3) / 4;
dim3 block(128);
if (input.dtype() == torch::kFloat32) {
......@@ -212,9 +264,10 @@ at::Tensor softmax(at::Tensor input, int rows, int cols) {
at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, int rows, int cols) {
CHECK_INPUT(output);
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
at::Tensor grad_input = at::empty_like(output);
int grid = rows / 4;
int grid = (rows + 3) / 4;
dim3 block(128);
if (output.dtype() == torch::kFloat32) {
......@@ -237,18 +290,29 @@ __global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, floa
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
float *row_input = input + row_offset * cols;
float *row_output = output + row_offset * cols;
float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * CUDART_INF_F;
} else {
......@@ -257,25 +321,26 @@ __global__ void fastfold_softmax_scale_mask_fp32(float *input, float *mask, floa
}
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
}
}
}
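From what is visible in this hunk, the fused forward sets positions where mask == 0 to -inf and, per its name, scales the surviving logits before a row-wise softmax; the scaling branch itself is elided here, so the following torch reference is a hedged reading rather than the extension's verified semantics:

import torch

def scale_mask_softmax_ref(x, mask, scale):
    # Assumed semantics: softmax over the last dim of scale * x, with mask == 0 forced to -inf.
    logits = torch.where(mask.bool(), x * scale, torch.full_like(x, float("-inf")))
    return torch.softmax(logits, dim=-1)

x = torch.randn(2, 8, 16)
mask = torch.ones(2, 1, 16)               # broadcast over the head/row dim (an assumption)
mask[..., 5:9] = 0
out = scale_mask_softmax_ref(x, mask, scale=0.125)

assert torch.allclose(out.sum(dim=-1), torch.ones(2, 8))   # rows still normalise
assert (out[..., 5:9] == 0).all()                           # masked columns get zero weight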
__global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloat16 *mask,
......@@ -284,18 +349,29 @@ __global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloa
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
at::BFloat16 *row_input = input + row_offset * cols;
at::BFloat16 *row_output = output + row_offset * cols;
at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * CUDART_INF_F;
} else {
......@@ -304,36 +380,38 @@ __global__ void fastfold_softmax_scale_mask_bfp16(at::BFloat16 *input, at::BFloa
}
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
}
}
}
at::Tensor fused_scale_mask_softmax_forward(at::Tensor input, at::Tensor mask, int rows, int cols,
float scale) {
CHECK_INPUT(input);
CHECK_INPUT(mask);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
int head = input.sizes()[2];
at::Tensor output = at::empty_like(input);
int grid = rows / 4;
int grid = (rows + 3) / 4;
dim3 block(128);
if (input.dtype() == torch::kFloat32) {
......@@ -355,13 +433,24 @@ __global__ void fastfold_softmax_scale_mask_grad_fp32(float *d_output, float *ou
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float y_buf[32];
float dy_buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
float *row_d_output = d_output + row_offset * cols;
float *row_output = output + row_offset * cols;
float *row_d_input = d_input + row_offset * cols;
......@@ -369,23 +458,23 @@ __global__ void fastfold_softmax_scale_mask_grad_fp32(float *d_output, float *ou
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
y_buf[i] = row_output[lane_id * cols_per_thread + i];
dy_buf[i] = row_d_output[lane_id * cols_per_thread + i];
}
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_sum += y_buf[i] * dy_buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
row_d_input[lane_id * cols_per_thread + i] =
scale * ((dy_buf[i] - warp_sum) * y_buf[i]);
......@@ -393,6 +482,7 @@ __global__ void fastfold_softmax_scale_mask_grad_fp32(float *d_output, float *ou
row_d_input[lane_id * cols_per_thread + i] = 0;
}
}
}
}
__global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, at::BFloat16 *output,
......@@ -401,13 +491,24 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float y_buf[32];
float dy_buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
at::BFloat16 *row_d_output = d_output + row_offset * cols;
at::BFloat16 *row_output = output + row_offset * cols;
at::BFloat16 *row_d_input = d_input + row_offset * cols;
......@@ -415,23 +516,23 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
y_buf[i] = static_cast<float>(row_output[lane_id * cols_per_thread + i]);
dy_buf[i] = static_cast<float>(row_d_output[lane_id * cols_per_thread + i]);
}
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_sum += y_buf[i] * dy_buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
if (mask_ptr[lane_id * cols_per_thread + i] != 0) {
row_d_input[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>(scale * ((dy_buf[i] - warp_sum) * y_buf[i]));
......@@ -439,16 +540,18 @@ __global__ void fastfold_softmax_scale_mask_grad_bfp16(at::BFloat16 *d_output, a
row_d_input[lane_id * cols_per_thread + i] = static_cast<at::BFloat16>(0.f);
}
}
}
}
at::Tensor fused_scale_mask_softmax_backward(at::Tensor d_output, at::Tensor output,
at::Tensor mask, int rows, int cols, float scale) {
CHECK_INPUT(output);
CHECK_INPUT(mask);
const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
int head = output.sizes()[2];
at::Tensor grad_input = at::empty_like(output);
int grid = rows / 4;
int grid = (rows + 3) / 4;
dim3 block(128);
if (output.dtype() == torch::kFloat32) {
......@@ -473,19 +576,30 @@ __global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask,
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
float *row_input = input + row_offset * cols;
float *row_output = output + row_offset * cols;
float *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
float *bias_ptr = bias + ((row_offset % (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * CUDART_INF_F;
} else {
......@@ -495,25 +609,26 @@ __global__ void fastfold_softmax_scale_mask_bias_fp32(float *input, float *mask,
}
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_output[lane_id * cols_per_thread + i] = __fdividef(buf[i], warp_sum);
}
}
}
__global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::BFloat16 *mask,
......@@ -522,19 +637,30 @@ __global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::
int threadidx_x = threadIdx.x / 32;
int threadidx_y = threadIdx.x % 32;
int row_offset = blockIdx.x * 4 + threadidx_x;
int cols_per_thread = cols / 32;
int cols_per_thread = (cols + 31) / 32;
int cols_this_thread = cols_per_thread;
int last_y = (cols / cols_per_thread);
if (threadidx_y == last_y) {
cols_this_thread = cols - cols_per_thread * last_y;
}
else if (threadidx_y > last_y) {
cols_this_thread = 0;
}
float buf[32];
int lane_id = threadidx_y;
if (row_offset < rows) {
at::BFloat16 *row_input = input + row_offset * cols;
at::BFloat16 *row_output = output + row_offset * cols;
at::BFloat16 *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
at::BFloat16 *bias_ptr = bias + ((row_offset % (head * cols)) * cols);
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
buf[i] = -1 * CUDART_INF_F;
} else {
......@@ -544,26 +670,27 @@ __global__ void fastfold_softmax_scale_mask_bias_bfp16(at::BFloat16 *input, at::
}
float thread_max = -1 * CUDART_INF_F;
#pragma unroll
for (int i = 0; i < cols_per_thread; i++) {
#pragma unroll
for (int i = 0; i < cols_this_thread; i++) {
thread_max = max(thread_max, buf[i]);
}
float warp_max = WarpAllReduceMax(thread_max);
float thread_sum = 0.f;
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
buf[i] = __expf(buf[i] - warp_max);
thread_sum += buf[i];
}
float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
for (int i = 0; i < cols_per_thread; ++i) {
#pragma unroll
for (int i = 0; i < cols_this_thread; ++i) {
row_output[lane_id * cols_per_thread + i] =
static_cast<at::BFloat16>(__fdividef(buf[i], warp_sum));
}
}
}
at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
......@@ -571,10 +698,11 @@ at::Tensor fused_scale_mask_bias_softmax_forward(at::Tensor input, at::Tensor ma
CHECK_INPUT(input);
CHECK_INPUT(mask);
CHECK_INPUT(bias);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
int head = input.sizes()[2];
at::Tensor output = at::empty_like(input);
int grid = rows / 4;
int grid = (rows + 3) / 4;
dim3 block(128);
if (input.dtype() == torch::kFloat32) {
......@@ -596,10 +724,11 @@ at::Tensor fused_scale_mask_bias_softmax_backward(at::Tensor d_output, at::Tenso
int cols, float scale) {
CHECK_INPUT(output);
CHECK_INPUT(mask);
const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));
int head = output.sizes()[2];
at::Tensor grad_input = at::empty_like(output);
int grid = rows / 4;
int grid = (rows + 3) / 4;
dim3 block(128);
if (output.dtype() == torch::kFloat32) {
......
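The scale-mask backward kernels zero the gradient at masked positions and apply scale * (dy_i - sum_j y_j * dy_j) * y_i elsewhere. Assuming the forward computes a softmax of scale * x with masked logits at -inf (as read above), that rule matches autograd exactly; a standalone torch check under that assumption:

import torch

scale = 0.3
x = torch.randn(4, 24, dtype=torch.float64, requires_grad=True)
mask = torch.ones(4, 24, dtype=torch.float64)
mask[:, 7:12] = 0

logits = torch.where(mask.bool(), scale * x, torch.full_like(x, float("-inf")))
y = torch.softmax(logits, dim=-1)
dy = torch.randn_like(y)

# Rule used by the backward kernels: scale * (dy - sum(y*dy)) * y where mask != 0, else 0.
dx_manual = torch.where(mask.bool(),
                        scale * (dy - (y * dy).sum(dim=-1, keepdim=True)) * y,
                        torch.zeros_like(y))
dx_autograd, = torch.autograd.grad(y, x, grad_outputs=dy)
assert torch.allclose(dx_manual, dx_autograd)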
......@@ -9,14 +9,14 @@ def bias_sigmod_ele(y, bias, z):
@torch.jit.script
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
residual: torch.Tensor, prob: float) -> torch.Tensor:
out = (x + bias) * F.dropout(dropmask, p=prob, training=True)
residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
out = (x + bias) * F.dropout(dropmask, p=prob, training=training)
out = residual + out
return out
@torch.jit.script
def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
prob: float) -> torch.Tensor:
return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))
dropout_mask: torch.Tensor, Z_raw: torch.Tensor, prob: float,
training: bool) -> torch.Tensor:
return Z_raw + F.dropout(dropout_mask, p=prob, training=training) * (g * (ab + b))
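The new training argument is the point of this change: F.dropout(..., training=True) keeps dropping and rescaling even after model.eval(), whereas threading self.training through (as the call sites below now do) makes these residual branches deterministic at inference. A minimal check of the two behaviours:

import torch
import torch.nn.functional as F

mask = torch.ones(3, 5)

# training=False: dropout is the identity, so the residual branch is deterministic.
assert torch.equal(F.dropout(mask, p=0.25, training=False), mask)

# training=True: elements are zeroed and survivors are rescaled by 1 / (1 - p).
dropped = F.dropout(mask, p=0.25, training=True)
assert ((dropped == 0) | torch.isclose(dropped, torch.tensor(1 / 0.75))).all()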
......@@ -50,7 +50,7 @@ class MSARowAttentionWithPairBias(nn.Module):
M = self.attention(M, M_mask, (b, work))
dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)
return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop, training=self.training)
class MSAColumnAttention(nn.Module):
......
......@@ -65,7 +65,8 @@ class TriangleMultiplicationOutgoing(nn.Module):
g,
dropout_mask,
Z_raw,
prob=self.p_drop)
prob=self.p_drop,
training=self.training)
class TriangleMultiplicationIncoming(nn.Module):
......@@ -103,10 +104,7 @@ class TriangleMultiplicationIncoming(nn.Module):
left_proj_act = gather_async_opp(left_proj_act, work, dim=2)
p = torch.matmul(
permute_final_dims(left_proj_act, (2, 1, 0)),
right_proj_act
)
p = torch.matmul(permute_final_dims(left_proj_act, (2, 1, 0)), right_proj_act)
ab = permute_final_dims(p, (1, 2, 0))
# ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
......@@ -117,7 +115,8 @@ class TriangleMultiplicationIncoming(nn.Module):
g,
dropout_mask,
Z_raw,
prob=self.p_drop)
prob=self.p_drop,
training=self.training)
class TriangleAttentionStartingNode(nn.Module):
......@@ -156,7 +155,12 @@ class TriangleAttentionStartingNode(nn.Module):
Z = self.attention(Z, Z_mask, (b, work))
dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
return bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
class TriangleAttentionEndingNode(nn.Module):
......@@ -197,7 +201,12 @@ class TriangleAttentionEndingNode(nn.Module):
Z = Z.transpose(-2, -3)
dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
return bias_dropout_add(Z,
self.out_bias,
dropout_mask,
Z_raw,
prob=self.p_drop,
training=self.training)
class PairStack(nn.Module):
......@@ -209,10 +218,20 @@ class PairStack(nn.Module):
self.n_head = 4
self.hidden_c = int(d_pair / self.n_head)
self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop, c=d_pair)
self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop, c=d_pair)
self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop, c=self.hidden_c, n_head=self.n_head)
self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop, c=self.hidden_c, n_head=self.n_head)
self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair,
p_drop=p_drop,
c=d_pair)
self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair,
p_drop=p_drop,
c=d_pair)
self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair,
p_drop=p_drop,
c=self.hidden_c,
n_head=self.n_head)
self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair,
p_drop=p_drop,
c=self.hidden_c,
n_head=self.n_head)
self.PairTransition = Transition(d=d_pair)
def forward(self, pair, pair_mask):
......
from .inject_openfold import inject_openfold
__all__ = ['inject_openfold']
\ No newline at end of file
from typing import Tuple, Optional
import torch
import torch.nn as nn
from fastfold.model import MSAStack, OutProductMean, PairStack
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
from fastfold.distributed.comm import gather, scatter
class EvoformerBlock(nn.Module):
def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool):
super(EvoformerBlock, self).__init__()
self.first_block = first_block
self.last_block = last_block
self.msa_stack = MSAStack(c_m, c_z, p_drop=0.15)
self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
self.pair_stack = PairStack(d_pair=c_z)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.first_block:
m = m.unsqueeze(0)
z = z.unsqueeze(0)
m = scatter(m, dim=1)
z = scatter(z, dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
m = self.msa_stack(m, z, msa_mask)
z = z + self.communication(m, msa_mask)
m, work = All_to_All_Async.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = All_to_All_Async_Opp.apply(m, work, 1, 2)
if self.last_block:
m = m.squeeze(0)
z = z.squeeze(0)
m = gather(m, dim=0)
z = gather(z, dim=0)
return m, z
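For orientation, scatter/gather split the MSA and pair tensors across the DAP ranks at the first block and reassemble them after the last, while the All_to_All_Async / All_to_All_Async_Opp pair re-partitions m between dimensions 1 and 2 and appears intended to overlap that communication with the pair stack. A single-process sketch of the partitioning idea only, using plain chunk/cat stand-ins rather than the real collectives:

import torch

world_size = 4
m = torch.randn(1, 128, 256, 64)          # toy (batch, seq, res, c_m)-like shape

# "scatter(m, dim=1)": each rank keeps one chunk of dim 1.
shards = list(torch.chunk(m, world_size, dim=1))

# "All_to_All(m, 1, 2)": trade the split in dim 1 for a split in dim 2,
# so every rank sees all of dim 1 but only a slice of dim 2.
regrouped = [torch.cat([torch.chunk(s, world_size, dim=2)[r] for s in shards], dim=1)
             for r in range(world_size)]

# "gather": undo the split.
assert torch.equal(torch.cat(shards, dim=1), m)
assert torch.equal(torch.cat(regrouped, dim=2), m)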
def copy_layernorm(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight)
model_fast.bias.copy_(model_ori.bias)
def copy_linear(model_fast, model_ori):
model_fast.weight.copy_(model_ori.weight)
if model_fast.use_bias:
model_fast.bias.copy_(model_ori.bias)
def copy_qkv_linear(model_fast, ori_q, ori_k, ori_v):
model_fast.weight.copy_(torch.cat((ori_q.weight, ori_k.weight, ori_v.weight), dim=0))
def copy_attention(model_fast, model_ori):
copy_qkv_linear(model_fast.to_qkv, model_ori.linear_q, model_ori.linear_k, model_ori.linear_v)
copy_linear(model_fast.gating_linear, model_ori.linear_g)
copy_linear(model_fast.o_linear, model_ori.linear_o)
try:
model_fast.gating_bias.copy_(model_ori.linear_g.bias)
except AttributeError:
    print("no gating_bias to copy")
def copy_left_right(model_fast, ori_left, ori_right):
model_fast.weight.copy_(torch.cat((ori_left.weight, ori_right.weight), dim=0))
model_fast.bias.copy_(torch.cat((ori_left.bias, ori_right.bias), dim=0))
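copy_qkv_linear and copy_left_right rely on the fact that concatenating weight matrices along the output dimension turns several linear projections into one wider fused projection whose output can simply be split back apart. A quick equivalence check with toy sizes, independent of the actual modules:

import torch
import torch.nn.functional as F

c = 16
q_w, k_w, v_w = (torch.randn(c, c) for _ in range(3))
x = torch.randn(8, c)

fused_w = torch.cat((q_w, k_w, v_w), dim=0)        # same layout copy_qkv_linear builds
q, k, v = F.linear(x, fused_w).chunk(3, dim=-1)

assert torch.allclose(q, F.linear(x, q_w))
assert torch.allclose(k, F.linear(x, k_w))
assert torch.allclose(v, F.linear(x, v_w))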
def copy_transition(model_fast, model_ori):
copy_layernorm(model_fast.norm, model_ori.layer_norm)
copy_linear(model_fast.linear1, model_ori.linear_1)
copy_linear(model_fast.linear2, model_ori.linear_2)
def copy_triangle(model_fast, model_ori):
copy_layernorm(model_fast.layernorm1, model_ori.layer_norm_in)
copy_layernorm(model_fast.layernorm2, model_ori.layer_norm_out)
copy_linear(model_fast.output_gate, model_ori.linear_g)
copy_linear(model_fast.output_projection, model_ori.linear_z)
model_fast.output_bias.copy_(model_ori.linear_z.bias)
copy_left_right(model_fast.left_right_projection, model_ori.linear_a_p, model_ori.linear_b_p)
copy_left_right(model_fast.left_right_gate, model_ori.linear_a_g, model_ori.linear_b_g)
def copy_triangle_att(model_fast, model_ori):
copy_layernorm(model_fast.layernorm1, model_ori.layer_norm)
copy_linear(model_fast.linear_b, model_ori.linear)
copy_attention(model_fast.attention, model_ori.mha)
model_fast.out_bias.copy_(model_ori.mha.linear_o.bias)
def copy_para(block_fast, block_ori):
# msa_stack
# MSARowAttentionWithPairBias
copy_layernorm(block_fast.msa_stack.MSARowAttentionWithPairBias.layernormM,
block_ori.msa_att_row.layer_norm_m)
copy_layernorm(block_fast.msa_stack.MSARowAttentionWithPairBias.layernormZ,
block_ori.msa_att_row.layer_norm_z)
copy_attention(block_fast.msa_stack.MSARowAttentionWithPairBias.attention,
block_ori.msa_att_row.mha)
block_fast.msa_stack.MSARowAttentionWithPairBias.linear_b_weights.copy_(
block_ori.msa_att_row.linear_z.weight)
block_fast.msa_stack.MSARowAttentionWithPairBias.out_bias.copy_(
block_ori.msa_att_row.mha.linear_o.bias)
# MSAColumnAttention
copy_layernorm(block_fast.msa_stack.MSAColumnAttention.layernormM,
block_ori.msa_att_col._msa_att.layer_norm_m)
copy_attention(block_fast.msa_stack.MSAColumnAttention.attention,
block_ori.msa_att_col._msa_att.mha)
# MSATransition
copy_transition(block_fast.msa_stack.MSATransition, block_ori.core.msa_transition)
# communication
copy_layernorm(block_fast.communication.layernormM,
block_ori.core.outer_product_mean.layer_norm)
copy_linear(block_fast.communication.linear_a, block_ori.core.outer_product_mean.linear_1)
copy_linear(block_fast.communication.linear_b, block_ori.core.outer_product_mean.linear_2)
copy_linear(block_fast.communication.o_linear, block_ori.core.outer_product_mean.linear_out)
# pair_stack
# TriangleMultiplicationOutgoing
copy_triangle(block_fast.pair_stack.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out)
# TriangleMultiplicationIncoming
copy_triangle(block_fast.pair_stack.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in)
# TriangleAttentionStartingNode
copy_triangle_att(block_fast.pair_stack.TriangleAttentionStartingNode,
block_ori.core.tri_att_start)
copy_triangle_att(block_fast.pair_stack.TriangleAttentionEndingNode, block_ori.core.tri_att_end)
copy_transition(block_fast.pair_stack.PairTransition, block_ori.core.pair_transition)
def inject_openfold(model):
with torch.no_grad():
fastfold_blocks = nn.ModuleList()
for block_id, openfold_block in enumerate(model.evoformer.blocks):
c_m = openfold_block.msa_att_row.c_in
c_z = openfold_block.msa_att_row.c_z
fastfold_block = EvoformerBlock(c_m=c_m,
c_z=c_z,
first_block=(block_id == 0),
last_block=(block_id == len(model.evoformer.blocks) - 1))
copy_para(fastfold_block, openfold_block)
fastfold_blocks.append(fastfold_block)
model.evoformer.blocks = fastfold_blocks
return model
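inject_openfold swaps model.evoformer.blocks for FastFold EvoformerBlocks and copies the OpenFold weights into them under torch.no_grad(), so it has to run after the weights are loaded and before the model is moved to its device, which is how the inference script below uses it. A minimal usage sketch (the parameter path is a placeholder):

import torch
from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_
from fastfold.utils import inject_openfold

config = model_config("model_1")
model = AlphaFold(config)
import_jax_weights_(model, "/path/to/params_model_1.npz", version="model_1")

model = inject_openfold(model)     # rebuild evoformer.blocks with FastFold blocks in place
model = model.eval().to("cuda:0")
# model(batch) then runs as in the inference script below, with `batch` a processed
# OpenFold feature dict moved to the same device.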
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import random
import sys
import time
from datetime import date
import numpy as np
import torch
import openfold.np.relax.relax as relax
from fastfold.utils import inject_openfold
from openfold.config import model_config
from openfold.data import data_pipeline, feature_pipeline, templates
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import protein, residue_constants
from openfold.utils.import_weights import import_jax_weights_
from openfold.utils.tensor_utils import tensor_tree_map
from scripts.utils import add_data_args
def main(args):
config = model_config(args.model_name)
model = AlphaFold(config)
import_jax_weights_(model, args.param_path, version=args.model_name)
model = inject_openfold(model)
model = model.eval()
#script_preset_(model)
model = model.to(args.model_device)
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=config.data.predict.max_templates,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=args.release_dates_path,
obsolete_pdbs_path=args.obsolete_pdbs_path)
use_small_bfd = (args.bfd_database_path is None)
data_processor = data_pipeline.DataPipeline(template_featurizer=template_featurizer,)
output_dir_base = args.output_dir
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(sys.maxsize)
feature_processor = feature_pipeline.FeaturePipeline(config.data)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
if (args.use_precomputed_alignments is None):
alignment_dir = os.path.join(output_dir_base, "alignments")
else:
alignment_dir = args.use_precomputed_alignments
# Gather input sequences
with open(args.fasta_path, "r") as fp:
lines = [l.strip() for l in fp.readlines()]
tags, seqs = lines[::2], lines[1::2]
tags = [l[1:] for l in tags]
for tag, seq in zip(tags, seqs):
fasta_path = os.path.join(args.output_dir, "tmp.fasta")
with open(fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}")
print("Generating features...")
local_alignment_dir = os.path.join(alignment_dir, tag)
if (args.use_precomputed_alignments is None):
if not os.path.exists(local_alignment_dir):
os.makedirs(local_alignment_dir)
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
alignment_runner.run(fasta_path, local_alignment_dir)
feature_dict = data_processor.process_fasta(fasta_path=fasta_path,
alignment_dir=local_alignment_dir)
# Remove temporary FASTA file
os.remove(fasta_path)
processed_feature_dict = feature_processor.process_features(
feature_dict,
mode='predict',
)
print("Executing model...")
batch = processed_feature_dict
with torch.no_grad():
batch = {k: torch.as_tensor(v, device=args.model_device) for k, v in batch.items()}
t = time.perf_counter()
out = model(batch)
print(f"Inference time: {time.perf_counter() - t}")
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
plddt = out["plddt"]
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(plddt[..., None], residue_constants.atom_type_num, axis=-1)
unrelaxed_protein = protein.from_prediction(features=batch,
result=out,
b_factors=plddt_b_factors)
# Save the unrelaxed PDB.
unrelaxed_output_path = os.path.join(args.output_dir,
f'{tag}_{args.model_name}_unrelaxed.pdb')
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
amber_relaxer = relax.AmberRelaxation(
use_gpu=(args.model_device != "cpu"),
**config.relax,
)
# Relax the prediction.
t = time.perf_counter()
visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
if ("cuda" in args.model_device):
device_no = args.model_device.split(":")[-1]
os.environ["CUDA_VISIBLE_DEVICES"] = device_no
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
if visible_devices:
os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices
print(f"Relaxation time: {time.perf_counter() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(args.output_dir, f'{tag}_{args.model_name}_relaxed.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"fasta_path",
type=str,
)
parser.add_argument(
"template_mmcif_dir",
type=str,
)
parser.add_argument("--use_precomputed_alignments",
type=str,
default=None,
help="""Path to alignment directory. If provided, alignment computation
is skipped and database path arguments are ignored.""")
parser.add_argument(
"--output_dir",
type=str,
default=os.getcwd(),
help="""Name of the directory in which to output the prediction""",
)
parser.add_argument("--model_device",
type=str,
default="cpu",
help="""Name of the device on which to run the model. Any valid torch
device name is accepted (e.g. "cpu", "cuda:0")""")
parser.add_argument("--model_name",
type=str,
default="model_1",
help="""Name of a model config. Choose one of model_{1-5} or
model_{1-5}_ptm, as defined on the AlphaFold GitHub.""")
parser.add_argument("--param_path",
type=str,
default=None,
help="""Path to model parameters. If None, parameters are selected
automatically according to the model name from
openfold/resources/params""")
parser.add_argument("--cpus",
type=int,
default=12,
help="""Number of CPUs with which to run alignment tools""")
parser.add_argument('--preset',
type=str,
default='reduced_dbs',
choices=('reduced_dbs', 'full_dbs'))
parser.add_argument('--data_random_seed', type=str, default=None)
add_data_args(parser)
args = parser.parse_args()
if (args.param_path is None):
args.param_path = os.path.join("openfold", "resources", "params",
"params_" + args.model_name + ".npz")
if (args.model_device == "cpu" and torch.cuda.is_available()):
logging.warning("""The model is being run on CPU. Consider specifying
--model_device for better performance""")
main(args)