OpenDAS / OpenFold · Commits · dba44612

Commit dba44612, authored Apr 28, 2022 by Gustaf Ahdritz

Resolve merge conflicts

Parents: 4bd1b4d5, 576174f0
Showing 14 changed files with 544 additions and 74 deletions (+544 −74)
openfold/utils/kernel/csrc/softmax_cuda_kernel.cu   +241  −0
openfold/utils/loss.py                              +6    −33
openfold/utils/lr_schedulers.py                     +1    −1
openfold/utils/superimposition.py                   +29   −5
openfold/utils/validation_metrics.py                +32   −1
run_pretrained_openfold.py                          +12   −1
scripts/generate_mmcif_cache.py                     +1    −1
setup.py                                            +66   −3
tests/test_import_weights.py                        +1    −1
tests/test_kernels.py                               +84   −0
tests/test_loss.py                                  +17   −1
tests/test_outer_product_mean.py                    +1    −1
tests/test_pair_transition.py                       +1    −1
train_openfold.py                                   +52   −25
openfold/utils/kernel/csrc/softmax_cuda_kernel.cu (new file, mode 100644)
// Copyright 2021 AlQuraishi Laboratory
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
#include <math_constants.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <iostream>
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
__inline__ __device__ float WarpAllReduceMax(float val) {
    for (int mask = 1; mask < 32; mask *= 2) {
        val = max(val, __shfl_xor_sync(0xffffffff, val, mask));
    }
    return val;
}

__inline__ __device__ float WarpAllReduceSum(float val) {
    for (int mask = 1; mask < 32; mask *= 2) {
        val += __shfl_xor_sync(0xffffffff, val, mask);
    }
    return val;
}
template<typename T>
__global__ void attn_softmax_inplace_(
    T *input,
    long long rows, int cols
)
{
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
    long long row_offset = (long long)(blockIdx.x * 4 + threadidx_x);
    int cols_per_thread = (cols + 31) / 32;
    int cols_this_thread = cols_per_thread;

    int last_y = (cols / cols_per_thread);

    if (threadidx_y == last_y) {
        cols_this_thread = cols - cols_per_thread * last_y;
    }
    else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_input = input + row_offset * cols;
        T *row_output = row_input;

        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            int idx = lane_id * cols_per_thread + i;
            buf[i] = static_cast<float>(row_input[idx]);
        }

        float thread_max = -1 * CUDART_INF_F;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_max = max(thread_max, buf[i]);
        }

        float warp_max = WarpAllReduceMax(thread_max);

        float thread_sum = 0.f;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            buf[i] = __expf(buf[i] - warp_max);
            thread_sum += buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);

        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] =
                static_cast<T>(__fdividef(buf[i], warp_sum));
        }
    }
}
void attn_softmax_inplace_forward_(
    at::Tensor input,
    long long rows, int cols
)
{
    CHECK_INPUT(input);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

    int grid = (rows + 3) / 4;
    dim3 block(128);

    if (input.dtype() == torch::kFloat32) {
        attn_softmax_inplace_<float><<<grid, block>>>(
            (float *)input.data_ptr(),
            rows, cols
        );
    }
    else {
        attn_softmax_inplace_<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)input.data_ptr(),
            rows, cols
        );
    }
}
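The forward kernel assigns one 32-thread warp per softmax row (four rows per 128-thread block, hence the (rows + 3) / 4 grid), each lane owning a contiguous chunk of up to 32 columns; the row max and normalizer are exchanged across the warp with __shfl_xor_sync. As a minimal sketch, not part of this commit, here is the same computation in plain PyTorch; the commented call into the compiled attn_core_inplace_cuda module (named in setup.py below) assumes a forward_ binding whose exact Python signature is not shown in this diff.

import torch

def softmax_inplace_reference(x: torch.Tensor) -> torch.Tensor:
    # Numerically stable softmax over the last dimension, matching the
    # kernel's per-row max-subtraction followed by exp and normalization.
    x = x - x.max(dim=-1, keepdim=True).values
    x = x.exp()
    return x / x.sum(dim=-1, keepdim=True)

# Hypothetical check against the compiled extension; the binding name
# comes from setup.py, but the forward_ signature is an assumption.
# import attn_core_inplace_cuda
# t = torch.rand(16, 64, device="cuda")
# expected = softmax_inplace_reference(t.clone())
# attn_core_inplace_cuda.forward_(t, t.shape[0], t.shape[1])
# assert torch.allclose(t, expected, atol=1e-6)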
template<typename T>
__global__ void attn_softmax_inplace_grad_(
    T *output,
    T *d_ov,
    T *values,
    long long rows,
    int cols_output,
    int cols_values
)
{
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
    long long row_offset = (long long)(blockIdx.x * 4 + threadidx_x);
    int cols_per_thread = (cols_output + 31) / 32;
    int cols_this_thread = cols_per_thread;
    int rows_values = cols_output;
    // values are set to the beginning of the current
    // rows_values x cols_values leaf matrix
    long long value_row_offset = row_offset - row_offset % rows_values;
    int last_y = (cols_output / cols_per_thread);

    if (threadidx_y == last_y) {
        cols_this_thread = cols_output - cols_per_thread * last_y;
    }
    else if (threadidx_y > last_y) {
        cols_this_thread = 0;
    }

    float y_buf[32];
    float dy_buf[32];

    int lane_id = threadidx_y;

    if (row_offset < rows) {
        T *row_output = output + row_offset * cols_output;
        T *row_d_ov = d_ov + row_offset * cols_values;
        T *row_values = values + value_row_offset * cols_values;

        float thread_max = -1 * CUDART_INF_F;

        // Compute a chunk of the output gradient on the fly
        int value_row_idx = 0;
        int value_idx = 0;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            T sum = 0.;
            #pragma unroll
            for (int j = 0; j < cols_values; j++) {
                value_row_idx = ((lane_id * cols_per_thread) + i);
                value_idx = value_row_idx * cols_values + j;
                sum += row_d_ov[j] * row_values[value_idx];
            }
            dy_buf[i] = static_cast<float>(sum);
        }

        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            y_buf[i] = static_cast<float>(
                row_output[lane_id * cols_per_thread + i]
            );
        }

        float thread_sum = 0.;
        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            thread_sum += y_buf[i] * dy_buf[i];
        }

        float warp_sum = WarpAllReduceSum(thread_sum);

        #pragma unroll
        for (int i = 0; i < cols_this_thread; i++) {
            row_output[lane_id * cols_per_thread + i] =
                static_cast<T>((dy_buf[i] - warp_sum) * y_buf[i]);
        }
    }
}
void attn_softmax_inplace_backward_(
    at::Tensor output,
    at::Tensor d_ov,
    at::Tensor values,
    long long rows,
    int cols_output,
    int cols_values
)
{
    CHECK_INPUT(output);
    CHECK_INPUT(d_ov);
    CHECK_INPUT(values);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(output));

    int grid = (rows + 3) / 4;
    dim3 block(128);

    if (output.dtype() == torch::kFloat32) {
        attn_softmax_inplace_grad_<float><<<grid, block>>>(
            (float *)output.data_ptr(),
            (float *)d_ov.data_ptr(),
            (float *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    }
    else {
        attn_softmax_inplace_grad_<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)output.data_ptr(),
            (at::BFloat16 *)d_ov.data_ptr(),
            (at::BFloat16 *)values.data_ptr(),
            rows, cols_output, cols_values
        );
    }
}
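The backward kernel reconstructs the upstream softmax gradient on the fly, dy = row_d_ov · row_valuesᵀ, then applies the softmax Jacobian in place: dx_i = (dy_i − Σ_j y_j dy_j) · y_i, with the inner sum reduced across the warp exactly like the forward normalizer. A small sketch, plain PyTorch only and not part of this commit, checking that formula against autograd:

import torch

def softmax_grad_reference(y: torch.Tensor, dy: torch.Tensor) -> torch.Tensor:
    # dX = (dY - sum(Y * dY, last dim)) * Y: the Jacobian-vector product
    # the kernel applies after forming dY = d_ov @ values^T per row.
    warp_sum = torch.sum(y * dy, dim=-1, keepdim=True)
    return (dy - warp_sum) * y

x = torch.rand(4, 8, requires_grad=True)
y = torch.softmax(x, dim=-1)
dy = torch.rand_like(y)
y.backward(dy)
assert torch.allclose(x.grad, softmax_grad_reference(y.detach(), dy), atol=1e-6)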
openfold/utils/loss.py

...
@@ -334,10 +334,12 @@ def supervised_chi_loss(
         (true_chi_shifted - pred_angles) ** 2, dim=-1
     )
     sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)

     # The ol' switcheroo
     sq_chi_error = sq_chi_error.permute(
         *range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
     )

     sq_chi_loss = masked_mean(
         chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
     )
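As a shape-only illustration of the "switcheroo" above (example dimensions assumed): the permute moves the leading dimension of sq_chi_error into third-from-last position, so that chi_mask[..., None, :, :] broadcasts against it and masked_mean can reduce over dim=(-1, -2, -3).

import torch

t = torch.zeros(8, 2, 5, 64, 7)  # hypothetical [leading, batch..., N, 7]
permuted = t.permute(*range(len(t.shape))[1:-2], 0, -2, -1)
print(permuted.shape)  # torch.Size([2, 5, 8, 64, 7])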
...
@@ -1526,39 +1528,6 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
     return loss


-def compute_drmsd(structure_1, structure_2, mask=None):
-    if(mask is not None):
-        structure_1 = structure_1 * mask[..., None]
-        structure_2 = structure_2 * mask[..., None]
-
-    d1 = structure_1[..., :, None, :] - structure_1[..., None, :, :]
-    d2 = structure_2[..., :, None, :] - structure_2[..., None, :, :]
-
-    d1 = d1 ** 2
-    d2 = d2 ** 2
-
-    d1 = torch.sqrt(torch.sum(d1, dim=-1))
-    d2 = torch.sqrt(torch.sum(d2, dim=-1))
-
-    drmsd = d1 - d2
-    drmsd = drmsd ** 2
-    drmsd = torch.sum(drmsd, dim=(-1, -2))
-    n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1)
-    drmsd = drmsd * (1 / (n * (n - 1)))
-    drmsd = torch.sqrt(drmsd)
-
-    return drmsd
-
-
-def compute_drmsd_np(structure_1, structure_2, mask=None):
-    structure_1 = torch.tensor(structure_1)
-    structure_2 = torch.tensor(structure_2)
-    if(mask is not None):
-        mask = torch.tensor(mask)
-
-    return compute_drmsd(structure_1, structure_2, mask)
-
-
 class AlphaFoldLoss(nn.Module):
     """Aggregation of the various losses described in the supplement"""
     def __init__(self, config):
...
@@ -1627,6 +1596,10 @@ class AlphaFoldLoss(nn.Module):
             weight = self.config[loss_name].weight
             loss = loss_fn()
             if(torch.isnan(loss) or torch.isinf(loss)):
+                #for k,v in batch.items():
+                #    if(torch.any(torch.isnan(v)) or torch.any(torch.isinf(v))):
+                #        logging.warning(f"{k}: is nan")
+                #logging.warning(f"{loss_name}: {loss}")
                 logging.warning(f"{loss_name} loss is NaN. Skipping...")
                 loss = loss.new_tensor(0., requires_grad=True)
             cum_loss = cum_loss + weight * loss
...
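The NaN/inf guard works because new_tensor(0., requires_grad=True) produces a zero on the same device and dtype that still participates in autograd, so cum_loss remains differentiable even when a term is skipped. A minimal demonstration, not part of this commit:

import torch

loss = torch.tensor(float("nan"))
if torch.isnan(loss) or torch.isinf(loss):
    # A fresh zero that keeps the graph alive; backward() still works.
    loss = loss.new_tensor(0., requires_grad=True)

cum_loss = 2.0 * loss
cum_loss.backward()
print(loss.grad)  # tensor(2.)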
openfold/utils/lr_schedulers.py

...
@@ -17,7 +17,7 @@ class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
         base_lr: float = 0.,
         max_lr: float = 0.001,
         warmup_no_steps: int = 1000,
-        start_decay_after_n_steps: int = 10000,
+        start_decay_after_n_steps: int = 50000,
         decay_every_n_steps: int = 50000,
         decay_factor: float = 0.95,
     ):
...
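Reading the defaults above, the schedule appears to warm up linearly to max_lr, hold flat, then multiply by decay_factor every decay_every_n_steps once start_decay_after_n_steps is reached; the change delays the first decay from step 10,000 to 50,000. A rough sketch of that interpretation only; the actual logic in AlphaFoldLRScheduler may differ:

def lr_at_step(step, base_lr=0., max_lr=0.001, warmup_no_steps=1000,
               start_decay_after_n_steps=50000, decay_every_n_steps=50000,
               decay_factor=0.95):
    # Assumed reading of the defaults shown in the diff, not the real class.
    if step < warmup_no_steps:
        return base_lr + (max_lr - base_lr) * (step / warmup_no_steps)
    if step < start_decay_after_n_steps:
        return max_lr
    n_decays = (step - start_decay_after_n_steps) // decay_every_n_steps + 1
    return max_lr * (decay_factor ** n_decays)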
openfold/utils/superimposition.py

...
@@ -42,7 +42,7 @@ def _superimpose_single(reference, coords):
     return coords.new_tensor(superimposed), coords.new_tensor(rmsd)


-def superimpose(reference, coords):
+def superimpose(reference, coords, mask):
     """
     Superimposes coordinates onto a reference by minimizing RMSD using SVD.
...
@@ -51,18 +51,42 @@ def superimpose(reference, coords):
             [*, N, 3] reference tensor
         coords:
             [*, N, 3] tensor
+        mask:
+            [*, N] tensor
     Returns:
         A tuple of [*, N, 3] superimposed coords and [*] final RMSDs.
     """
+    def select_unmasked_coords(coords, mask):
+        return torch.masked_select(
+            coords,
+            (mask > 0.)[..., None],
+        ).reshape(-1, 3)
+
     batch_dims = reference.shape[:-2]
     flat_reference = reference.reshape((-1,) + reference.shape[-2:])
     flat_coords = coords.reshape((-1,) + reference.shape[-2:])
+    flat_mask = mask.reshape((-1,) + mask.shape[-1:])
     superimposed_list = []
     rmsds = []
-    for r, c in zip(flat_reference, flat_coords):
-        superimposed, rmsd = _superimpose_single(r, c)
-        superimposed_list.append(superimposed)
-        rmsds.append(rmsd)
+    for r, c, m in zip(flat_reference, flat_coords, flat_mask):
+        r_unmasked_coords = select_unmasked_coords(r, m)
+        c_unmasked_coords = select_unmasked_coords(c, m)
+        superimposed, rmsd = _superimpose_single(
+            r_unmasked_coords, c_unmasked_coords
+        )
+
+        # This is very inelegant, but idk how else to invert the masking
+        # procedure.
+        count = 0
+        superimposed_full_size = torch.zeros_like(r)
+        for i, unmasked in enumerate(m):
+            if(unmasked):
+                superimposed_full_size[i] = superimposed[count]
+                count += 1
+
+        superimposed_list.append(superimposed_full_size)
+        rmsds.append(rmsd)

     superimposed_stacked = torch.stack(superimposed_list, dim=0)
     rmsds_stacked = torch.stack(rmsds, dim=0)
...
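superimpose now fits only the unmasked atoms and scatters the superimposed coordinates back into a full-size tensor, leaving zeros at masked positions. A hypothetical call, assuming openfold is installed, with shapes taken from the docstring:

import torch
from openfold.utils.superimposition import superimpose

reference = torch.rand(2, 10, 3)   # [*, N, 3]
coords = torch.rand(2, 10, 3)      # [*, N, 3]
mask = torch.ones(2, 10)           # [*, N]
mask[:, -2:] = 0.                  # exclude the last two atoms from the fit

# superimposed has zeros at masked positions; the RMSDs are computed
# over the unmasked atoms only.
superimposed, rmsds = superimpose(reference, coords, mask)
print(superimposed.shape, rmsds.shape)  # torch.Size([2, 10, 3]) torch.Size([2])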
openfold/utils/validation_metrics.py

...
@@ -14,16 +14,47 @@
 import torch


+def drmsd(structure_1, structure_2, mask=None):
+    def prep_d(structure):
+        d = structure[..., :, None, :] - structure[..., None, :, :]
+        d = d ** 2
+        d = torch.sqrt(torch.sum(d, dim=-1))
+        return d
+
+    d1 = prep_d(structure_1)
+    d2 = prep_d(structure_2)
+
+    drmsd = d1 - d2
+    drmsd = drmsd ** 2
+    if(mask is not None):
+        drmsd = drmsd * (mask[..., None] * mask[..., None, :])
+    drmsd = torch.sum(drmsd, dim=(-1, -2))
+    n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1)
+
+    drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.)
+    drmsd = torch.sqrt(drmsd)
+
+    return drmsd
+
+
+def drmsd_np(structure_1, structure_2, mask=None):
+    structure_1 = torch.tensor(structure_1)
+    structure_2 = torch.tensor(structure_2)
+    if(mask is not None):
+        mask = torch.tensor(mask)
+
+    return drmsd(structure_1, structure_2, mask)
+
+
 def gdt(p1, p2, mask, cutoffs):
     n = torch.sum(mask, dim=-1)

     p1 = p1.float()
     p2 = p2.float()
     distances = torch.sqrt(torch.sum((p1 - p2) ** 2, dim=-1))
     scores = []
     for c in cutoffs:
         score = torch.sum((distances <= c) * mask, dim=-1) / n
+        score = torch.mean(score)
         scores.append(score)

     return sum(scores) / len(scores)
...
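drmsd compares the two intra-structure distance matrices, so unlike RMSD it needs no prior superimposition; with a mask, pairs touching masked residues are zeroed and n counts only unmasked residues. A quick sanity check, assuming openfold is importable:

import torch
from openfold.utils.validation_metrics import drmsd

coords = torch.rand(10, 3)

# A structure compared with itself has zero dRMSD, and the score is
# invariant to rigid translations since only internal distances enter.
print(drmsd(coords, coords))                 # tensor(0.)
shifted = coords + torch.tensor([1., 2., 3.])
print(drmsd(coords, shifted))                # ~tensor(0.), up to float error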
run_pretrained_openfold.py

...
@@ -234,7 +234,7 @@ def main(args):
     # Relax the prediction.
     t = time.perf_counter()
-    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
+    visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", default="")
     if("cuda" in args.model_device):
         device_no = args.model_device.split(":")[-1]
         os.environ["CUDA_VISIBLE_DEVICES"] = device_no
...
@@ -249,6 +249,13 @@ def main(args):
     with open(relaxed_output_path, 'w') as f:
         f.write(relaxed_pdb_str)

+    if(args.save_outputs):
+        output_dict_path = os.path.join(
+            args.output_dir, f'{tag}_{args.model_name}_output_dict.pkl'
+        )
+        with open(output_dict_path, "wb") as fp:
+            pickle.dump(out, fp, protocol=pickle.HIGHEST_PROTOCOL)
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
...
@@ -283,6 +290,10 @@ if __name__ == "__main__":
         automatically according to the model name from
         openfold/resources/params"""
     )
+    parser.add_argument(
+        "--save_outputs", type=bool, default=False,
+        help="Whether to save all model outputs, including embeddings, etc."
+    )
     parser.add_argument(
         "--cpus", type=int, default=4,
         help="""Number of CPUs with which to run alignment tools"""
...
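The os.getenv change matters because the call returns None when CUDA_VISIBLE_DEVICES is unset, while os.environ only accepts strings; default="" keeps the saved value a string, which matters if it is later written back into the environment. A minimal demonstration:

import os

os.environ.pop("CUDA_VISIBLE_DEVICES", None)
print(os.getenv("CUDA_VISIBLE_DEVICES"))              # None
print(os.getenv("CUDA_VISIBLE_DEVICES", default=""))  # ''

# Restoring a saved value requires a str; assigning None raises TypeError.
os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("CUDA_VISIBLE_DEVICES", default="")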
scripts/generate_mmcif_cache.py

...
@@ -27,7 +27,7 @@ def parse_file(f, args):
     local_data = {}
     local_data["release_date"] = mmcif.header["release_date"]
-    chain_ids, seqs = mmcif.chain_to_seqres.items()
+    chain_ids, seqs = list(zip(*mmcif.chain_to_seqres.items()))
     local_data["chain_ids"] = chain_ids
     local_data["seqs"] = seqs
     local_data["no_chains"] = len(chain_ids)
...
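The one-line fix addresses a tuple-unpacking bug: dict.items() yields (key, value) pairs, so the old line only "worked" for an mmCIF with exactly two chains, and even then assigned pairs rather than keys and values; zip(*items) transposes the pairs into a tuple of chain IDs and a tuple of sequences. A minimal demonstration, sample values assumed:

d = {"A": "MKV...", "B": "GHH..."}

# Broken: assigns ("A", "MKV...") and ("B", "GHH..."), and raises
# ValueError for any chain count other than two.
# chain_ids, seqs = d.items()

chain_ids, seqs = list(zip(*d.items()))
print(chain_ids)  # ('A', 'B')
print(seqs)       # ('MKV...', 'GHH...')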
setup.py

...
@@ -12,8 +12,46 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from setuptools import find_packages
-from setuptools import setup
+import os
+from setuptools import setup, Extension, find_packages
+import subprocess
+
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+
+version_dependent_macros = [
+    '-DVERSION_GE_1_1',
+    '-DVERSION_GE_1_3',
+    '-DVERSION_GE_1_5',
+]
+
+extra_cuda_flags = [
+    '-std=c++14',
+    '-maxrregcount=50',
+    '-U__CUDA_NO_HALF_OPERATORS__',
+    '-U__CUDA_NO_HALF_CONVERSIONS__',
+    '--expt-relaxed-constexpr',
+    '--expt-extended-lambda'
+]
+
+def get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output(
+        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
+    )
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
+
+_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
+if int(bare_metal_major) >= 11:
+    cc_flag.append('-gencode')
+    cc_flag.append('arch=compute_80,code=sm_80')
+
+extra_cuda_flags += cc_flag

 setup(
     name='openfold',
...
@@ -25,7 +63,32 @@ setup(
     url='https://github.com/aqlaboratory/openfold',
     packages=find_packages(exclude=["tests", "scripts"]),
     include_package_data=True,
-    package_data={"": ["resources/stereo_chemical_props.txt"]},
+    package_data={
+        "openfold": ['utils/kernel/csrc/*'],
+        "": ["resources/stereo_chemical_props.txt"]
+    },
+    ext_modules=[CUDAExtension(
+        name="attn_core_inplace_cuda",
+        sources=[
+            "openfold/utils/kernel/csrc/softmax_cuda.cpp",
+            "openfold/utils/kernel/csrc/softmax_cuda_kernel.cu",
+        ],
+        include_dirs=[
+            os.path.join(
+                os.path.dirname(os.path.abspath(__file__)),
+                'openfold/utils/kernel/csrc/'
+            )
+        ],
+        extra_compile_args={
+            'cxx': ['-O3'] + version_dependent_macros,
+            'nvcc': (
+                ['-O3', '--use_fast_math'] +
+                version_dependent_macros +
+                extra_cuda_flags
+            ),
+        }
+    )],
+    cmdclass={'build_ext': BuildExtension},
     install_requires=[
         'torch',
         'deepspeed',
...
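get_cuda_bare_metal_version shells out to nvcc -V and takes the token after "release"; the sm_80 gencode flags are appended only when the toolkit major version is at least 11, since compute_80 targets require CUDA 11. A sketch of the parsing against a canned nvcc banner, sample text assumed:

# Parsing sketch; the banner string below is illustrative, not captured
# from a real machine.
raw_output = (
    "nvcc: NVIDIA (R) Cuda compiler driver\n"
    "Cuda compilation tools, release 11.3, V11.3.109"
)
output = raw_output.split()
release = output[output.index("release") + 1].split(".")
print(release[0], release[1][0])  # 11 3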
tests/test_import_weights.py

...
@@ -65,7 +65,7 @@ class TestImportWeights(unittest.TestCase):
             )
         ][1].transpose(-1, -2)
     ),
-    model.evoformer.blocks[1].outer_product_mean.linear_1.weight,
+    model.evoformer.blocks[1].core.outer_product_mean.linear_1.weight,
     ),
 ]
...
tests/test_kernels.py (new file, mode 100644)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import unittest

from openfold.model.primitives import _attention
from openfold.utils.kernel.attention_core import attention_core
from tests.config import consts


class TestAttentionCore(unittest.TestCase):
    def test_attention_core_forward(self):
        n_res = consts.n_res
        h = consts.n_heads_extra_msa
        n_seq = consts.n_extra
        c = consts.c_e
        dtype = torch.float32

        q = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
        k = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
        v = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
        mask = torch.randint(0, 2, [n_seq, n_res]).cuda()
        mask_bias = (1e9 * mask - 1)[..., None, None, :].to(dtype)

        out_repro = attention_core(q, k, v, mask_bias, None)
        out_gt = _attention(q, k, v, [mask_bias])

        self.assertTrue(torch.max(torch.abs(out_repro - out_gt)) < consts.eps)

    def test_attention_core_backward(self):
        n_res = consts.n_res
        h = consts.n_heads_extra_msa
        n_seq = consts.n_extra
        c = consts.c_e
        dtype = torch.float32

        q = torch.rand(
            [n_seq, h, n_res, c], dtype=dtype, requires_grad=True
        ).cuda()
        k = torch.rand(
            [n_seq, h, n_res, c], dtype=dtype, requires_grad=True
        ).cuda()
        v = torch.rand(
            [n_seq, h, n_res, c], dtype=dtype, requires_grad=True
        ).cuda()
        mask = torch.randint(0, 2, [n_seq, n_res]).cuda()
        mask_bias = (1e9 * mask - 1)[..., None, None, :].to(dtype)

        def clone(t):
            t = t.clone()
            if(t.requires_grad):
                t.retain_grad()
            return t

        q_repro = clone(q)
        k_repro = clone(k)
        v_repro = clone(v)
        out_repro = attention_core(
            q_repro, k_repro, v_repro, mask_bias, None
        )

        loss_repro = torch.mean(out_repro)
        loss_repro.backward()

        q_gt = clone(q)
        k_gt = clone(k)
        v_gt = clone(v)
        out_gt = _attention(
            q_gt, k_gt, v_gt, [mask_bias]
        )

        loss_gt = torch.mean(out_gt)
        loss_gt.backward()

        pairs = zip([q_repro, k_repro, v_repro], [q_gt, k_gt, v_gt])
        for t_repro, t_gt in pairs:
            self.assertTrue(
                torch.max(torch.abs(t_repro.grad - t_gt.grad)) < consts.eps
            )


if __name__ == '__main__':
    unittest.main()
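The tests build the additive mask as (1e9 * mask - 1): roughly +1e9 at allowed positions and -1 at masked ones. Because softmax subtracts the row max, masked entries end up about 1e9 below the maximum and underflow to zero probability. A sketch in float64 with assumed values; note that in float32 the +1e9 offset also swamps the logits themselves (the ulp at 1e9 is 64), which is harmless here only because the kernel and the reference receive identical biased inputs:

import torch

mask = torch.tensor([1., 1., 0.], dtype=torch.float64)
logits = torch.tensor([0.3, 0.1, 0.9], dtype=torch.float64)

# After softmax's internal max-subtraction, masked entries sit ~1e9
# below the row max and get zero weight.
mask_bias = 1e9 * mask - 1
print(torch.softmax(logits + mask_bias, dim=-1))
# tensor([0.5498, 0.4502, 0.0000], dtype=torch.float64)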
tests/test_loss.py

...
@@ -42,6 +42,7 @@ from openfold.utils.loss import (
     backbone_loss,
     sidechain_loss,
     tm_loss,
+    compute_plddt,
 )
 from openfold.utils.tensor_utils import (
     tree_map,
...
@@ -226,6 +227,21 @@ class TestLoss(unittest.TestCase):
             torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
         )

+    @compare_utils.skip_unless_alphafold_installed()
+    def test_compute_plddt_compare(self):
+        n_res = consts.n_res
+
+        logits = np.random.rand(n_res, 50)
+
+        out_gt = alphafold.common.confidence.compute_plddt(logits)
+        out_gt = torch.tensor(out_gt)
+
+        logits_t = torch.tensor(logits)
+        out_repro = compute_plddt(logits_t)
+
+        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
+
     def test_find_structural_violations(self):
         n = consts.n_res
...
@@ -655,7 +671,7 @@ class TestLoss(unittest.TestCase):
         self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)

     @compare_utils.skip_unless_alphafold_installed()
-    def test_backbone_loss(self):
+    def test_backbone_loss_compare(self):
         config = compare_utils.get_alphafold_config()
         c_sm = config.model.heads.structure_module
...
tests/test_outer_product_mean.py

...
@@ -81,7 +81,7 @@ class TestOuterProductMean(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         out_repro = (
             model.evoformer.blocks[0]
-            .outer_product_mean(
+            .core.outer_product_mean(
                 torch.as_tensor(msa_act).cuda(),
                 chunk_size=4,
...
tests/test_pair_transition.py

...
@@ -76,7 +76,7 @@ class TestPairTransition(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         out_repro = (
             model.evoformer.blocks[0]
-            .pair_transition(
+            .core.pair_transition(
                 torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
                 chunk_size=4,
...
train_openfold.py

...
@@ -8,6 +8,7 @@ import os
 #os.environ["NODE_RANK"]="0"
 import random
+import sys
 import time

 import numpy as np
...
@@ -27,16 +28,18 @@ from openfold.data.data_modules import (
 from openfold.model.model import AlphaFold
 from openfold.model.torchscript import script_preset_
 from openfold.np import residue_constants
-from openfold.utils.argparse import remove_arguments
 from openfold.utils.callbacks import (
     EarlyStoppingVerbose,
 )
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
-from openfold.utils.loss import AlphaFoldLoss, lddt_ca, compute_drmsd
+from openfold.utils.argparse_utils import remove_arguments
+from openfold.utils.loss import AlphaFoldLoss, lddt_ca
+from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
 from openfold.utils.seed import seed_everything
 from openfold.utils.superimposition import superimpose
 from openfold.utils.tensor_utils import tensor_tree_map
 from openfold.utils.validation_metrics import (
+    drmsd,
     gdt_ts,
     gdt_ha,
 )
...
@@ -72,12 +75,12 @@ class OpenFoldWrapper(pl.LightningModule):
             on_step=train, on_epoch=(not train), logger=True,
         )

         if(train):
             self.log(
-                f"train/loss_epoch",
-                loss_breakdown["loss"],
+                f"{phase}/{loss_name}_epoch",
+                indiv_loss,
                 on_step=False, on_epoch=True, logger=True,
             )

         with torch.no_grad():
             other_metrics = self._compute_validation_metrics(
...
@@ -116,16 +119,14 @@ class OpenFoldWrapper(pl.LightningModule):
     def on_before_zero_grad(self, *args, **kwargs):
         self.ema.update(self.model)

-    # def training_step_end(self, outputs):
-    #     # Temporary measure to address DeepSpeed scheduler bug (PL issue 11694)
-    #     if(self.trainer.global_step != self.last_lr_step):
-    #         self.lr_schedulers().step()
-    #         self.last_lr_step = self.trainer.global_step
-
     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if(self.cached_weights is None):
-            self.cached_weights = self.model.state_dict()
+            # model.state_dict() contains references to model weights rather
+            # than copies. Therefore, we need to clone them before calling
+            # load_state_dict().
+            clone_param = lambda t: t.detach().clone()
+            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
+
         self.model.load_state_dict(self.ema.state_dict()["params"])

         # Run the model
...
@@ -171,20 +172,20 @@ class OpenFoldWrapper(pl.LightningModule):
             eps=self.config.globals.eps,
             per_residue=False,
         )
         metrics["lddt_ca"] = lddt_ca_score

-        drmsd_ca_score = compute_drmsd(
+        drmsd_ca_score = drmsd(
             pred_coords_masked_ca,
             gt_coords_masked_ca,
-            mask=all_atom_mask_ca,
+            mask=all_atom_mask_ca, # still required here to compute n
         )
         metrics["drmsd_ca"] = drmsd_ca_score

         if(superimposition_metrics):
-            superimposed_pred, _ = superimpose(
-                gt_coords_masked_ca, pred_coords_masked_ca
+            superimposed_pred, alignment_rmsd = superimpose(
+                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
             )
             gdt_ts_score = gdt_ts(
                 superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
...
@@ -193,6 +194,7 @@ class OpenFoldWrapper(pl.LightningModule):
                 superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
             )
+            metrics["alignment_rmsd"] = alignment_rmsd
             metrics["gdt_ts"] = gdt_ts_score
             metrics["gdt_ha"] = gdt_ha_score
...
@@ -203,11 +205,23 @@ class OpenFoldWrapper(pl.LightningModule):
         eps: float = 1e-5,
     ) -> torch.optim.Adam:
         # Ignored as long as a DeepSpeed optimizer is configured
-        return torch.optim.Adam(
+        optimizer = torch.optim.Adam(
             self.model.parameters(),
             lr=learning_rate,
             eps=eps
         )

+        lr_scheduler = AlphaFoldLRScheduler(
+            optimizer,
+        )
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                "interval": "step",
+                "name": "AlphaFoldLRScheduler",
+            }
+        }
+
     def on_load_checkpoint(self, checkpoint):
         self.ema.load_state_dict(checkpoint["ema"])
...
@@ -232,7 +246,7 @@ def main(args):
         sd = {k[len("module."):]: v for k, v in sd.items()}
         model_module.load_state_dict(sd)
         logging.info("Successfully loaded model weights...")

     # TorchScript components of the model
     if(args.script_modules):
         script_preset_(model_module)
...
@@ -251,6 +265,8 @@ def main(args):
     if(args.checkpoint_every_epoch):
         mc = ModelCheckpoint(
             every_n_epochs=1,
+            auto_insert_metric_name=False,
+            save_top_k=-1,
         )
         callbacks.append(mc)
...
@@ -300,7 +316,12 @@ def main(args):
         strategy = DDPPlugin(find_unused_parameters=False)
     else:
         strategy = None

+    if(args.wandb):
+        freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
+        os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
+        wdb_logger.experiment.save(f"{freeze_path}")
+
     trainer = pl.Trainer.from_argparse_args(
         args,
         default_root_dir=args.output_dir,
...
@@ -487,9 +508,15 @@ if __name__ == "__main__":
             'used.'
         )
     )
+    parser.add_argument(
+        "--_distillation_structure_index_path", type=str, default=None,
+    )
     parser.add_argument(
         "--_alignment_index_path", type=str, default=None,
     )
+    parser.add_argument(
+        "--_distillation_alignment_index_path", type=str, default=None,
+    )
     parser = pl.Trainer.add_argparse_args(parser)

     # Disable the initial validation pass
...
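The cached-weights fix in validation_step works because state_dict() returns tensors that share storage with the live parameters; without detach().clone(), loading the EMA weights for validation would silently overwrite the "cache" too. A minimal demonstration, not part of this commit:

import torch

model = torch.nn.Linear(2, 2)

cached = model.state_dict()              # references, not copies
with torch.no_grad():
    model.weight.mul_(0.)                # e.g. swapping in EMA weights
print(cached["weight"].abs().sum())      # tensor(0.) -- cache destroyed

model = torch.nn.Linear(2, 2)
clone_param = lambda t: t.detach().clone()
cached = {k: clone_param(v) for k, v in model.state_dict().items()}
with torch.no_grad():
    model.weight.mul_(0.)
print(cached["weight"].abs().sum() > 0)  # tensor(True) -- cache intact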