Unverified Commit b19827f6 authored by dianyo's avatar dianyo Committed by GitHub
Browse files

Migrate the BrownianTree to BrownianInterval in DPM solver (#9335)



migrate the BrownianTree to BrownianInterval
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent c002731d
......@@ -38,7 +38,20 @@ class BatchedBrownianTree:
except TypeError:
seed = [seed]
self.batched = False
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
self.trees = [
torchsde.BrownianInterval(
t0=t0,
t1=t1,
size=w0.shape,
dtype=w0.dtype,
device=w0.device,
entropy=s,
tol=1e-6,
pool_size=24,
halfway_tree=True,
)
for s in seed
]
@staticmethod
def sort(a, b):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment