test_sample_time.py 566 Bytes
Newer Older
wangsen's avatar
wangsen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import jax
import pymc as pm
import numpy as np
import pytensor as pt
# Fit a simple linear regression (Y = alpha + beta * X + noise) on synthetic
# data with PyMC, sampling through the NumPyro NUTS backend, and print the
# posterior summary table.

# Run PyTensor in single precision; set before the model graph is built.
pt.config.floatX = "float32"

# One seed drives both the synthetic data and the sampler so reruns match.
SEED = 123
np.random.seed(SEED)

# Synthetic data: slope 3, intercept 0, unit-variance Gaussian noise.
n = 10000
X = np.random.randn(n)
Y = 3 * X + np.random.randn(n)

# Define the PyMC model: wide Normal priors on intercept and slope, a
# HalfNormal prior on the noise scale, and a Normal likelihood on Y.
with pm.Model() as model:
    alpha = pm.Normal("alpha", mu=0, sigma=10)
    beta = pm.Normal("beta", mu=0, sigma=10)
    sigma = pm.HalfNormal("sigma", sigma=1)
    mu = alpha + beta * X
    Y_obs = pm.Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)
    # return_inferencedata is left at its default so `trace` is an
    # InferenceData object, which is the container pm.summary consumes.
    trace = pm.sample(1000, nuts_sampler="numpyro", random_seed=SEED)

print(pm.summary(trace))